[general] creation of dnn module for gemm/conv triton routines

This commit is contained in:
Philippe Tillet
2019-05-06 17:47:06 -04:00
parent f80441017c
commit fd91368f98
8 changed files with 430 additions and 297 deletions

View File

@@ -266,24 +266,23 @@ torch::Tensor conv_forward(
kernel->setArg(6, B);
kernel->setArg(7, H);
kernel->setArg(8, W);
kernel->setArg(9, B);
kernel->setArg(10, NF);
kernel->setArg(11, P);
kernel->setArg(12, Q);
kernel->setArg(13, Ci);
kernel->setArg(14, R);
kernel->setArg(15, S);
kernel->setArg(16, stride_i_n);
kernel->setArg(17, stride_i_c);
kernel->setArg(18, stride_i_h);
kernel->setArg(19, stride_i_w);
kernel->setArg(20, stride_o_n);
kernel->setArg(21, stride_o_k);
kernel->setArg(22, stride_o_p);
kernel->setArg(23, stride_o_q);
kernel->setArg(24, pad_h);
kernel->setArg(25, pad_w);
kernel->setArg(26, bound);
kernel->setArg(9, NF);
kernel->setArg(10, P);
kernel->setArg(11, Q);
kernel->setArg(12, Ci);
kernel->setArg(13, R);
kernel->setArg(14, S);
kernel->setArg(15, stride_i_n);
kernel->setArg(16, stride_i_c);
kernel->setArg(17, stride_i_h);
kernel->setArg(18, stride_i_w);
kernel->setArg(19, stride_o_n);
kernel->setArg(20, stride_o_k);
kernel->setArg(21, stride_o_p);
kernel->setArg(22, stride_o_q);
kernel->setArg(23, pad_h);
kernel->setArg(24, pad_w);
kernel->setArg(25, bound);
// // dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
return output;