Made sure it works for FP16
@@ -50,8 +50,8 @@ public:
 int32_t pad_d = 0, pad_h = 0, pad_w = 0;
 bool has_bias = false;
 // wrap buffers
-triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
-triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
+triton::driver::cu_buffer a(ctx, tfa.tensor_data().size(), (CUdeviceptr)tfa.tensor_data().data(), false);
+triton::driver::cu_buffer b(ctx, tfb.tensor_data().size(), (CUdeviceptr)tfb.tensor_data().data(), false);
 triton::driver::buffer* bias = nullptr;
 // template
 triton::dnn::conv conv(B, C,
@@ -68,7 +68,7 @@ public:
 Tensor* tfc = nullptr;
 TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
-triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
+triton::driver::cu_buffer c(ctx, tfc->tensor_data().size(), (CUdeviceptr)tfc->tensor_data().data(), false);
 // enqueue
 conv.enqueue(stream, {&a, &b, &c, bias});
 }
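
The change swaps the dtype-specific `flat<Eigen::half>()` / `flat<float>()` accessors for `Tensor::tensor_data()`, which exposes the tensor's raw bytes and their size regardless of element type, so the same buffer-wrapping code serves FP16 and FP32 and the byte size is passed to the `cu_buffer` constructor explicitly. A minimal sketch of that idea follows; `wrap_tensor`, the `triton::driver::context*` parameter type, and the include paths are illustrative assumptions, not part of the commit — only the `cu_buffer(ctx, size, ptr, take_ownership)` call shape is taken from the diff:

#include <cuda.h>                                 // CUdeviceptr
#include "tensorflow/core/framework/tensor.h"     // tensorflow::Tensor
#include "triton/driver/buffer.h"                 // triton::driver::cu_buffer (path assumed)

// Hypothetical helper (not in the commit): wraps any TensorFlow tensor as a
// Triton device buffer without naming its element type, which is what lets
// the op body handle FP16 as well as FP32 once tensor_data() is used.
static triton::driver::cu_buffer wrap_tensor(triton::driver::context* ctx,
                                             const tensorflow::Tensor& t) {
  // tensor_data() returns a view over the tensor's raw bytes: size() is the
  // byte count and data() the address of the buffer backing the TF tensor.
  return triton::driver::cu_buffer(ctx,
                                   t.tensor_data().size(),
                                   (CUdeviceptr)t.tensor_data().data(),
                                   false);
}

With such a helper, the op could build `a`, `b`, and `c` as `auto a = wrap_tensor(ctx, tfa);` and so on, instead of repeating the cast and size lookup for each tensor and each dtype.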