Made sure it works for FP16

Author: Philippe Tillet
Date:   2019-07-30 20:02:16 -07:00
parent 080bf1af88
commit 5af7e5adac
21 changed files with 118 additions and 101 deletions


@@ -51,12 +51,12 @@ public:
   OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
   OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
   // triton handles
-  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-  triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
-  triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
-  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
-  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
+  triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+  triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+  triton::driver::cu_buffer b(ctx, fw_b.tensor_data().size(), (CUdeviceptr)fw_b.tensor_data().data(), false);
+  triton::driver::cu_buffer y(ctx, fw_y->tensor_data().size(), (CUdeviceptr)fw_y->tensor_data().data(), false);
+  triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
+  triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
   // create config
   triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
   batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
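
The added lines above all follow one pattern; here it is for a single tensor as a minimal sketch, assuming (from the diff itself) a constructor of the form cu_buffer(ctx, size_in_bytes, ptr, take_ownership) and that this code sits inside the op's Compute method. The switch from flat<float>() to tensor_data() is what makes the wrapping dtype-agnostic: tensor_data() exposes the tensor's raw bytes rather than a typed FP32 view, so the same path now covers FP16.

  // Sketch of the new wrapping pattern for one tensor (fw_x as above).
  const auto bytes = fw_x.tensor_data();     // raw byte view, valid for any dtype
  triton::driver::cu_buffer x(ctx,
                              bytes.size(),  // buffer size in BYTES, not elements
                              (CUdeviceptr)bytes.data(),
                              false);        // Triton must not free TF-owned memory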
@@ -117,14 +117,14 @@ public:
   OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
   OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
   // triton handles
-  triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
-  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
-  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
-  triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
-  triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
-  triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
+  triton::driver::cu_buffer dy(ctx, fw_dy.tensor_data().size(), (CUdeviceptr)fw_dy.tensor_data().data(), false);
+  triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+  triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+  triton::driver::cu_buffer m(ctx, fw_m.tensor_data().size(), (CUdeviceptr)fw_m.tensor_data().data(), false);
+  triton::driver::cu_buffer v(ctx, fw_v.tensor_data().size(), (CUdeviceptr)fw_v.tensor_data().data(), false);
+  triton::driver::cu_buffer dx(ctx, fw_dx->tensor_data().size(), (CUdeviceptr)fw_dx->tensor_data().data(), false);
+  triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
+  triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
   // create config
   triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
   batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
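
Note that both hunks still pass a hard-coded "fp32" type string to the batchnorm config; only the buffer wrapping became dtype-agnostic in this commit. A hypothetical follow-up (not part of this diff, and the accepted type strings are an assumption) could derive the string from the tensor's dtype:

  // Hypothetical dispatch, not in this commit: pick the Triton type string
  // from the TensorFlow dtype instead of hard-coding "fp32".
  std::string ty = (fw_x.dtype() == tensorflow::DT_HALF) ? "fp16" : "fp32";
  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, ty);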