Made sure it works for FP16

Philippe Tillet
2019-07-30 20:02:16 -07:00
parent 080bf1af88
commit 5af7e5adac
21 changed files with 118 additions and 101 deletions


@@ -50,8 +50,8 @@ public:
 int32_t pad_d = 0, pad_h = 0, pad_w = 0;
 bool has_bias = false;
 // wrap buffers
-triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
-triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
+triton::driver::cu_buffer a(ctx, tfa.tensor_data().size(), (CUdeviceptr)tfa.tensor_data().data(), false);
+triton::driver::cu_buffer b(ctx, tfb.tensor_data().size(), (CUdeviceptr)tfb.tensor_data().data(), false);
 triton::driver::buffer* bias = nullptr;
 // template
 triton::dnn::conv conv(B, C,
@@ -68,7 +68,7 @@ public:
 Tensor* tfc = nullptr;
 TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
-triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
+triton::driver::cu_buffer c(ctx, tfc->tensor_data().size(), (CUdeviceptr)tfc->tensor_data().data(), false);
 // enqueue
 conv.enqueue(stream, {&a, &b, &c, bias});
 }
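
Both hunks make the same change: instead of Tensor::flat<Eigen::half>() / flat<float>(), which is templated on the element type, the buffers are wrapped through Tensor::tensor_data(), which exposes the tensor's raw bytes regardless of dtype, and the byte size is passed explicitly to the cu_buffer constructor. That is what lets the op handle FP16 without a separate Eigen::half code path. Below is a minimal sketch of the pattern, assuming ctx is a triton::driver::context*, the triton/driver/buffer.h header path, and a hypothetical wrap_tensor_example helper name (none of which are spelled out in the commit itself):

#include <cuda.h>
#include "tensorflow/core/framework/tensor.h"
#include "triton/driver/buffer.h"

// Sketch only: wrap a TensorFlow tensor of any element type (FP16, FP32, ...)
// as a non-owning Triton device buffer, mirroring the diff above.
static void wrap_tensor_example(triton::driver::context* ctx,
                                const tensorflow::Tensor& tfa) {
  tensorflow::StringPiece bytes = tfa.tensor_data();         // raw byte view, dtype-agnostic
  triton::driver::cu_buffer a(ctx,
                              bytes.size(),                  // buffer size in bytes
                              (CUdeviceptr)bytes.data(),     // pointer to the tensor's device memory
                              false);                        // do not take ownership
  (void)a;  // in the real op, a is passed to conv.enqueue(stream, {&a, &b, &c, bias})
}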