[dnn] better specification of recompilation key
This commit is contained in:
@@ -58,7 +58,7 @@ public:
|
||||
triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
|
||||
triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
|
||||
// create config
|
||||
- triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
|
||||
+ triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
|
||||
batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ public:
|
||||
triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
|
||||
// create config
|
||||
- triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||
+ triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
|
||||
batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user