[dnn] better specification of recompilation key
This commit is contained in:
@@ -37,7 +37,7 @@ std::vector<torch::Tensor>
|
||||
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
|
||||
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
|
||||
// create template
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float");
|
||||
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
|
||||
stream.synchronize();
|
||||
return {fw_y, fw_m, fw_v};
|
||||
@@ -79,7 +79,7 @@ std::vector<torch::Tensor>
|
||||
triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", eps);
|
||||
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
|
||||
stream.synchronize();
|
||||
return {fw_dx, fw_dg, fw_db};
|
||||
|
Reference in New Issue
Block a user