[dnn/shift]: added support for fp16

Author: Philippe Tillet
Date: 2019-07-13 21:05:34 -07:00
Parent: fe42cb7142
Commit: 3e7a3ed67a
11 changed files with 76 additions and 43 deletions


@@ -45,11 +45,20 @@ torch::Tensor shift_common(
 CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
 triton::driver::cu_stream stream(custream, false);
 triton::driver::context* ctx = stream.context();
+// Data-type
+std::string dtype;
+at::ScalarType type = torcha.scalar_type();
+switch(type){
+  case at::ScalarType::Double: dtype = "fp64"; break;
+  case at::ScalarType::Float: dtype = "fp32"; break;
+  case at::ScalarType::Half: dtype = "fp16"; break;
+  default: AT_ERROR("unknown data-type for shift-conv");
+}
 // Get configuration
 bool has_bias = torchbias.storage().size() > 0;
 triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                          stride_h, stride_w,
-                         shift_h, shift_w, "fp32", "fp32",
+                         shift_h, shift_w, dtype, dtype,
                          ty, has_bias, layout);
 // Bind memory
 triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
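
(Not part of the commit: the scalar-type switch above can be exercised on its own. A sketch of the same mapping factored into a reusable helper; the name `triton_dtype` is hypothetical:)

static std::string triton_dtype(at::ScalarType type) {
  // Maps PyTorch scalar types to the dtype strings understood by
  // triton::dnn::shift; mirrors the switch added in the hunk above.
  switch (type) {
    case at::ScalarType::Double: return "fp64";
    case at::ScalarType::Float:  return "fp32";
    case at::ScalarType::Half:   return "fp16";
    default: AT_ERROR("unknown data-type for shift-conv");
  }
}

(Returning by value rather than assigning to an out-parameter also removes the uninitialized-`dtype` path that the in-place switch leaves open.)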
@@ -61,7 +70,9 @@ torch::Tensor shift_common(
 std::vector<long int> c_shapes;
 for(auto x: _c_shapes)
   c_shapes.push_back(x);
-torch::Tensor torchc = torch::empty(c_shapes).cuda();
+torch::Tensor torchc = torch::empty(c_shapes, type).cuda();
 triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
 // Enqueue
 shift.enqueue(&stream, {&a, &b, &c});
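
(Usage sketch, not from the commit; tensor shapes and names are illustrative. With `type` now threaded through to `torch::empty`, an fp16 input produces an fp16 output buffer:)

#include <torch/torch.h>

int main() {
  // Requires a CUDA build of libtorch. torch::kHalf corresponds to
  // at::ScalarType::Half, which the binding now maps to Triton's "fp16".
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  torch::Tensor a = torch::randn({32, 64, 56, 56}, opts);
  // Mirrors torch::empty(c_shapes, type) in the hunk above: the output is
  // allocated with the input's scalar type instead of the fp32 default.
  torch::Tensor c = torch::empty({32, 64, 56, 56}, a.options());
  return 0;
}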