Made sure it works for FP16
@@ -45,9 +45,9 @@ class DotOp : public OpKernel {
     if (out_shape.num_elements() == 0)
       return;
     // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
+    triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
     triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
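Compared with the old flat<Eigen::half>() / flat<float>() accessors, the new cu_buffer constructor takes the tensor's byte size together with its raw device pointer, so the same wrapping code works for the half-precision inputs and the float output alike. A minimal sketch of that pattern outside the TensorFlow kernel follows; the header paths, the fp16_dot helper, and the context/stream parameter types are assumptions for illustration, while the cu_buffer, triton::dnn::dot, and enqueue calls mirror the diff above.

// Sketch only, not part of this commit: header paths and the fp16_dot helper
// are assumptions; the cu_buffer(ctx, size, ptr, own), dot(...) and enqueue(...)
// calls follow the pattern shown in the diff.
#include <cuda.h>                  // CUdeviceptr
#include "triton/driver/buffer.h"  // assumed header for triton::driver::cu_buffer
#include "triton/dnn/dot.h"        // assumed header for triton::dnn::dot

void fp16_dot(triton::driver::context* ctx, triton::driver::stream* stream,
              CUdeviceptr d_a, CUdeviceptr d_b, CUdeviceptr d_c,
              int M, int N, int K) {
  // Wrap pre-allocated device memory by byte size; the trailing `false`
  // means the wrapper does not take ownership of the allocation.
  triton::driver::cu_buffer da(ctx, M * K * 2, d_a, false);             // fp16 A, 2 bytes/element
  triton::driver::cu_buffer db(ctx, K * N * 2, d_b, false);             // fp16 B
  triton::driver::cu_buffer dc(ctx, M * N * sizeof(float), d_c, false); // fp32 C
  // Non-transposed fp16 x fp16 dot with the same tile parameters as the kernel above.
  triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
  dot.enqueue(stream, {&da, &db, &dc});
}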