Made sure it works for FP16
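The change in both hunks is the same: device buffers are now built from Tensor::tensor_data(), which exposes a tensor's raw bytes and byte count regardless of its dtype, instead of Tensor::flat<float>(), which CHECK-fails unless the tensor is actually DT_FLOAT. A minimal standalone sketch of the distinction (illustration only, not code from this commit):

#include <cstddef>
#include "tensorflow/core/framework/tensor.h"

void sketch() {
  // An FP16 tensor: 16 half-precision elements, 32 bytes total.
  tensorflow::Tensor t(tensorflow::DT_HALF, tensorflow::TensorShape({16}));
  // t.flat<float>();  // would CHECK-fail: dtype is DT_HALF, not DT_FLOAT
  tensorflow::StringPiece raw = t.tensor_data();  // valid for any dtype
  const void* ptr = raw.data();     // what the diff casts to CUdeviceptr
  std::size_t bytes = raw.size();   // the new explicit size argument (32 here)
  (void)ptr; (void)bytes;
}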
@@ -51,12 +51,12 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
     // triton handles
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
-    triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer b(ctx, fw_b.tensor_data().size(), (CUdeviceptr)fw_b.tensor_data().data(), false);
+    triton::driver::cu_buffer y(ctx, fw_y->tensor_data().size(), (CUdeviceptr)fw_y->tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
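For reference, the constructor shapes implied by the old and new call sites would be roughly the following. This is a sketch reconstructed from the diff alone; the names and exact signatures are assumptions, not Triton's actual header:

#include <cstddef>
#include <cuda.h>

// Sketch only: constructor shapes inferred from the call sites above,
// not copied from Triton's headers; parameter names are guesses.
namespace triton { namespace driver {

class context;  // opaque here

class cu_buffer {
public:
  // old form: pointer only, so the element width had to be assumed elsewhere
  cu_buffer(context* ctx, CUdeviceptr ptr, bool take_ownership);
  // new form: byte size is explicit, which stays correct when the tensor
  // holds 2-byte FP16 elements instead of 4-byte floats
  cu_buffer(context* ctx, std::size_t size, CUdeviceptr ptr, bool take_ownership);
};

} }  // namespace triton::driver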
@@ -117,14 +117,14 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
     // triton handles
-    triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
-    triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
-    triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
+    triton::driver::cu_buffer dy(ctx, fw_dy.tensor_data().size(), (CUdeviceptr)fw_dy.tensor_data().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m.tensor_data().size(), (CUdeviceptr)fw_m.tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v.tensor_data().size(), (CUdeviceptr)fw_v.tensor_data().data(), false);
+    triton::driver::cu_buffer dx(ctx, fw_dx->tensor_data().size(), (CUdeviceptr)fw_dx->tensor_data().data(), false);
+    triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
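Because tensor_data() is dtype-agnostic, one kernel body can now serve both FP32 and FP16 inputs. A hedged sketch of how such an op might be registered for both types; the op name, class name, and attribute "T" are hypothetical, since the registration code is not part of this diff:

#include "tensorflow/core/framework/op_kernel.h"

// Hypothetical class, shown only to anchor the registrations below.
class BatchnormForwardOp : public tensorflow::OpKernel {
 public:
  explicit BatchnormForwardOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {}
  void Compute(tensorflow::OpKernelContext* context) override {
    // body as in the hunks above: tensor_data() works for either dtype
  }
};

// One registration per supported dtype; the Compute() body is shared.
REGISTER_KERNEL_BUILDER(Name("BatchnormForward")
                            .Device(tensorflow::DEVICE_GPU)
                            .TypeConstraint<float>("T"),
                        BatchnormForwardOp);
REGISTER_KERNEL_BUILDER(Name("BatchnormForward")
                            .Device(tensorflow::DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T"),
                        BatchnormForwardOp);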