[dnn] better specification of recompilation key

This commit is contained in:
Philippe Tillet
2019-08-02 17:42:48 -07:00
parent 3b92ddf7e6
commit d9945692a9
31 changed files with 418 additions and 428 deletions

View File

@@ -37,7 +37,7 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
// create template
-triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
+triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float");
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
stream.synchronize();
return {fw_y, fw_m, fw_v};
@@ -79,7 +79,7 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
// create config
-triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
+triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", eps);
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
stream.synchronize();
return {fw_dx, fw_dg, fw_db};