[dnn] better specification of recompilation key

This commit is contained in:
Philippe Tillet
2019-08-02 17:42:48 -07:00
parent 3b92ddf7e6
commit d9945692a9
31 changed files with 418 additions and 428 deletions

View File

@@ -58,7 +58,7 @@ public:
triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
}
@@ -126,7 +126,7 @@ public:
triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
}