[dnn/shift]: added support for fp16

This commit is contained in:
Philippe Tillet
2019-07-13 21:05:34 -07:00
parent fe42cb7142
commit 3e7a3ed67a
11 changed files with 76 additions and 43 deletions

View File

@@ -123,8 +123,6 @@ class ShiftConvFunction(torch.autograd.Function):
dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[2]:
dbias = torch.sum(dy, (1, 2, 3))
#print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
#print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
return dx, dw, dbias, None, None, None, None