Made sure it works for FP16
@@ -119,9 +119,9 @@ public:
   if (out_shapes.num_elements() == 0)
     return;
   // matrix multiplication parameters
-  triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
-  triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
-  triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
+  triton::driver::cu_buffer da(ctx, tf_a.tensor_data().size(), (CUdeviceptr)tf_a.tensor_data().data(), false);
+  triton::driver::cu_buffer db(ctx, tf_b.tensor_data().size(), (CUdeviceptr)tf_b.tensor_data().data(), false);
+  triton::driver::cu_buffer dc(ctx, tf_c->tensor_data().size(), (CUdeviceptr)tf_c->tensor_data().data(), false);
   shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
 }
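The change replaces flat<Eigen::half>(), which bakes the element type into the call, with tensor_data(), which exposes the tensor's raw bytes regardless of dtype, and it passes the buffer size to the cu_buffer constructor. A minimal sketch of the same pattern follows; the wrap_tensor helper name, the header paths, and the templated Context type are assumptions for illustration, while the cu_buffer (context, size, pointer, take_ownership) signature is taken from the diff above.

  #include "tensorflow/core/framework/tensor.h"
  #include "triton/driver/buffer.h"

  // Hypothetical helper (not from this commit): wrap any TF tensor in a
  // cu_buffer via tensor_data(), so FP16 tensors work the same as FP32.
  template <typename Context>
  triton::driver::cu_buffer wrap_tensor(Context ctx,
                                        const tensorflow::Tensor& t) {
    // tensor_data() returns a view over the tensor's underlying bytes,
    // independent of the element type.
    tensorflow::StringPiece bytes = t.tensor_data();
    return triton::driver::cu_buffer(ctx, bytes.size(),
                                     (CUdeviceptr)bytes.data(),
                                     /*take_ownership=*/false);
  }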