Made sure it works for FP16

commit 5af7e5adac (parent 080bf1af88)
Author: Philippe Tillet
Date:   2019-07-30 20:02:16 -07:00
21 changed files with 118 additions and 101 deletions


@@ -119,9 +119,9 @@ public:
 if (out_shapes.num_elements() == 0)
   return;
 // matrix multiplication parameters
-triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
-triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
-triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
+triton::driver::cu_buffer da(ctx, tf_a.tensor_data().size(), (CUdeviceptr)tf_a.tensor_data().data(), false);
+triton::driver::cu_buffer db(ctx, tf_b.tensor_data().size(), (CUdeviceptr)tf_b.tensor_data().data(), false);
+triton::driver::cu_buffer dc(ctx, tf_c->tensor_data().size(), (CUdeviceptr)tf_c->tensor_data().data(), false);
 shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
 }
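
The changed lines switch to the four-argument cu_buffer constructor, which also receives the buffer size in bytes, and take both the size and the device pointer from tensor_data() rather than the typed flat<Eigen::half>() view. The sketch below illustrates that pointer/size extraction; the raw_view struct and to_raw helper are hypothetical names for illustration only and are not part of this commit.

```cpp
#include <cstddef>
#include <cuda.h>                              // CUdeviceptr
#include "tensorflow/core/framework/tensor.h"  // tensorflow::Tensor

// Hypothetical helper (not in the commit): the (pointer, byte-size) pair that
// the new cu_buffer constructor calls above are built from.
struct raw_view {
  CUdeviceptr ptr;   // device pointer to the tensor's storage
  std::size_t size;  // size of that storage in bytes
};

// tensor_data() exposes the tensor's underlying bytes, so the same extraction
// works for Eigen::half (FP16) tensors and float tensors alike.
inline raw_view to_raw(const tensorflow::Tensor& t) {
  auto bytes = t.tensor_data();
  return raw_view{(CUdeviceptr)bytes.data(), bytes.size()};
}

// Usage, mirroring the diff:
//   raw_view a = to_raw(tf_a);
//   triton::driver::cu_buffer da(ctx, a.size, a.ptr, false);
```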