[dnn/shift]: added support for fp16
@@ -45,11 +45,20 @@ torch::Tensor shift_common(
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
   triton::driver::cu_stream stream(custream, false);
   triton::driver::context* ctx = stream.context();
+  // Data-type
+  std::string dtype;
+  at::ScalarType type = torcha.scalar_type();
+  switch(type){
+    case at::ScalarType::Double: dtype = "fp64"; break;
+    case at::ScalarType::Float: dtype = "fp32"; break;
+    case at::ScalarType::Half: dtype = "fp16"; break;
+    default: AT_ERROR("unknown data-type for shift-conv");
+  }
   // Get configuration
   bool has_bias = torchbias.storage().size() > 0;
   triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                            stride_h, stride_w,
-                           shift_h, shift_w, "fp32", "fp32",
+                           shift_h, shift_w, dtype, dtype,
                            ty, has_bias, layout);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
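The new switch maps ATen scalar types onto the dtype strings Triton expects, so the shift op is built with the input tensor's precision instead of a hard-coded "fp32". Read in isolation it amounts to the small helper sketched below; this is only an illustrative sketch, and triton_dtype is a hypothetical name, since the patch keeps the switch inline in shift_common.

// Illustrative sketch only: the patch inlines this switch in shift_common,
// and "triton_dtype" is a hypothetical helper name.
#include <ATen/ATen.h>
#include <string>

static std::string triton_dtype(at::ScalarType type) {
  switch (type) {
    case at::ScalarType::Double: return "fp64";
    case at::ScalarType::Float:  return "fp32";
    case at::ScalarType::Half:   return "fp16";   // newly supported path
    default: AT_ERROR("unknown data-type for shift-conv");
  }
}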
@@ -61,7 +70,9 @@ torch::Tensor shift_common(
   std::vector<long int> c_shapes;
   for(auto x: _c_shapes)
     c_shapes.push_back(x);
-  torch::Tensor torchc = torch::empty(c_shapes).cuda();
+  torch::Tensor torchc = torch::empty(c_shapes, type).cuda();
+
+
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
   shift.enqueue(&stream, {&a, &b, &c});