more cleaning of conv
This commit is contained in:
@@ -25,6 +25,7 @@ int main() {
|
||||
int32_t M = B*RD*RH*RW;
|
||||
int32_t N = NF;
|
||||
int32_t K = NC*T*R*S;
|
||||
// convolution configuration
|
||||
std::vector<float> hc(B*RH*RW*NF);
|
||||
std::vector<float> rc(B*RH*RW*NF);
|
||||
std::vector<float> ha(B*NC*H*W);
|
||||
@@ -57,8 +58,9 @@ int main() {
|
||||
int32_t stride_o_k = RD*stride_o_m;
|
||||
int32_t stride_o_n = NF*stride_o_k;
|
||||
// look-up table
|
||||
triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
|
||||
std::vector<int> h_delta, h_masks;
|
||||
triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
|
||||
configuration.build_lut(h_delta, h_masks);
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
|
Reference in New Issue
Block a user