[dnn/shift]: added support for fp16
@@ -45,11 +45,20 @@ torch::Tensor shift_common(
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
   triton::driver::cu_stream stream(custream, false);
   triton::driver::context* ctx = stream.context();
+  // Data-type
+  std::string dtype;
+  at::ScalarType type = torcha.scalar_type();
+  switch(type){
+    case at::ScalarType::Double: dtype = "fp64"; break;
+    case at::ScalarType::Float:  dtype = "fp32"; break;
+    case at::ScalarType::Half:   dtype = "fp16"; break;
+    default: AT_ERROR("unknown data-type for shift-conv");
+  }
   // Get configuration
   bool has_bias = torchbias.storage().size() > 0;
   triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                            stride_h, stride_w,
-                           shift_h, shift_w, "fp32", "fp32",
+                           shift_h, shift_w, dtype, dtype,
                            ty, has_bias, layout);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
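For readers comparing against the Python side, here is a minimal sketch of the same scalar-type mapping that the new switch performs; the dict and helper below are illustrative only, not part of the bindings:

import torch

# Illustrative mirror of the C++ switch above: torch scalar type ->
# dtype string understood by triton::dnn::shift.
_TRITON_DTYPE = {
    torch.float64: "fp64",
    torch.float32: "fp32",
    torch.float16: "fp16",
}

def triton_dtype(tensor):
    # Raises on any other scalar type, matching the
    # AT_ERROR("unknown data-type for shift-conv") default case.
    try:
        return _TRITON_DTYPE[tensor.dtype]
    except KeyError:
        raise TypeError("unknown data-type for shift-conv: %r" % tensor.dtype)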
@@ -61,7 +70,9 @@ torch::Tensor shift_common(
   std::vector<long int> c_shapes;
   for(auto x: _c_shapes)
     c_shapes.push_back(x);
-  torch::Tensor torchc = torch::empty(c_shapes).cuda();
+  torch::Tensor torchc = torch::empty(c_shapes, type).cuda();
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
   shift.enqueue(&stream, {&a, &b, &c});
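This hunk matters because torch::empty(c_shapes) previously allocated the output with the default scalar type (fp32), so an fp16 run would have enqueued a kernel writing into a mismatched buffer. A quick sketch of the equivalent behavior from Python, with illustrative names:

import torch

x = torch.randn(32, 32, 4, 4, device="cuda", dtype=torch.float16)
# The output allocation now follows the input's scalar type,
# analogous to torch::empty(c_shapes, type) on the C++ side.
y = torch.empty(x.shape, dtype=x.dtype, device="cuda")
assert y.dtype == torch.float16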
@@ -123,8 +123,6 @@ class ShiftConvFunction(torch.autograd.Function):
       dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
     if ctx.needs_input_grad[2]:
       dbias = torch.sum(dy, (1, 2, 3))
-    #print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
-    #print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
     return dx, dw, dbias, None, None, None, None
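The bias gradient kept by the last hunk reduces dy over every dimension except dim 0, which is assumed here to hold the channels (e.g. a C x H x W x N layout); the reduction itself is layout-agnostic as long as channels sit in dim 0. A minimal check of the shape it produces:

import torch

# dy laid out with channels first, e.g. C x H x W x N.
dy = torch.randn(8, 6, 6, 4)
dbias = torch.sum(dy, (1, 2, 3))  # one partial sum per channel
assert dbias.shape == (8,)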