[general] creation of dnn module for gemm/conv triton routines
This commit is contained in:
@@ -266,24 +266,23 @@ torch::Tensor conv_forward(
|
||||
kernel->setArg(6, B);
|
||||
kernel->setArg(7, H);
|
||||
kernel->setArg(8, W);
|
||||
kernel->setArg(9, B);
|
||||
kernel->setArg(10, NF);
|
||||
kernel->setArg(11, P);
|
||||
kernel->setArg(12, Q);
|
||||
kernel->setArg(13, Ci);
|
||||
kernel->setArg(14, R);
|
||||
kernel->setArg(15, S);
|
||||
kernel->setArg(16, stride_i_n);
|
||||
kernel->setArg(17, stride_i_c);
|
||||
kernel->setArg(18, stride_i_h);
|
||||
kernel->setArg(19, stride_i_w);
|
||||
kernel->setArg(20, stride_o_n);
|
||||
kernel->setArg(21, stride_o_k);
|
||||
kernel->setArg(22, stride_o_p);
|
||||
kernel->setArg(23, stride_o_q);
|
||||
kernel->setArg(24, pad_h);
|
||||
kernel->setArg(25, pad_w);
|
||||
kernel->setArg(26, bound);
|
||||
kernel->setArg(9, NF);
|
||||
kernel->setArg(10, P);
|
||||
kernel->setArg(11, Q);
|
||||
kernel->setArg(12, Ci);
|
||||
kernel->setArg(13, R);
|
||||
kernel->setArg(14, S);
|
||||
kernel->setArg(15, stride_i_n);
|
||||
kernel->setArg(16, stride_i_c);
|
||||
kernel->setArg(17, stride_i_h);
|
||||
kernel->setArg(18, stride_i_w);
|
||||
kernel->setArg(19, stride_o_n);
|
||||
kernel->setArg(20, stride_o_k);
|
||||
kernel->setArg(21, stride_o_p);
|
||||
kernel->setArg(22, stride_o_q);
|
||||
kernel->setArg(23, pad_h);
|
||||
kernel->setArg(24, pad_w);
|
||||
kernel->setArg(25, bound);
|
||||
// // dry run
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
return output;
|
||||
|
Reference in New Issue
Block a user