[dnn/conv] Added bias and forward stride

This commit is contained in:
Philippe Tillet
2019-05-20 12:20:29 -04:00
parent f33a1f3fe3
commit e8f23bcade
8 changed files with 303 additions and 199 deletions

View File

@@ -10,13 +10,15 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 56, W = 56;
int32_t NC = 32, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
// convolution configuration
std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size());
@@ -47,7 +49,7 @@ int main() {
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
stream->synchronize();
configuration.set_arg(kernel, da, db, dc);
configuration.set_arg(kernel, da, db, dc, nullptr);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},