more cleaning of conv

This commit is contained in:
Philippe Tillet
2019-05-06 19:30:22 -04:00
parent fd91368f98
commit 615569287e
2 changed files with 120 additions and 57 deletions

View File

@@ -25,6 +25,7 @@ int main() {
int32_t M = B*RD*RH*RW;
int32_t N = NF;
int32_t K = NC*T*R*S;
// convolution configuration
std::vector<float> hc(B*RH*RW*NF);
std::vector<float> rc(B*RH*RW*NF);
std::vector<float> ha(B*NC*H*W);
@@ -57,8 +58,9 @@ int main() {
int32_t stride_o_k = RD*stride_o_m;
int32_t stride_o_n = NF*stride_o_k;
// look-up table
triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
std::vector<int> h_delta, h_masks;
triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
configuration.build_lut(h_delta, h_masks);
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {