[dnn/shift]: added support for fp16
@@ -123,8 +123,6 @@ class ShiftConvFunction(torch.autograd.Function):
            dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
        if ctx.needs_input_grad[2]:
            dbias = torch.sum(dy, (1, 2, 3))
        #print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
        #print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
        return dx, dw, dbias, None, None, None, None
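For context, a minimal sketch of how fp16 support is typically wired into a custom torch.autograd.Function like ShiftConvFunction: the backward casts the incoming gradient to the dtype of the saved activations (half or float) before handing it to the kernels, and casts reduced gradients back to the parameter dtype. This is only an illustration under those assumptions, not the actual change in this commit; the names below (ShiftConvSketch, the placeholder kernel calls) are hypothetical stand-ins for the torch.ops.triton.shift_conv_* ops used in the real code.

import torch

class ShiftConvSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        # Placeholder forward; the real code launches the Triton shift-conv kernel.
        return input.clone()

    @staticmethod
    def backward(ctx, dy):
        input, weight, bias = ctx.saved_tensors
        # Run the backward in the same precision as the saved input (fp16 or fp32).
        dy = dy.to(input.dtype).contiguous()
        dx = dw = dbias = None
        if ctx.needs_input_grad[0]:
            dx = dy.clone()                # stand-in for a shift_conv_dx kernel call
        if ctx.needs_input_grad[1]:
            dw = torch.zeros_like(weight)  # stand-in for the shift_conv_dw kernel call
        if ctx.needs_input_grad[2]:
            # Reduce over the non-channel dims (mirroring the diff), then cast back
            # to the bias dtype so fp32 master parameters keep fp32 gradients.
            dbias = torch.sum(dy, (1, 2, 3)).to(bias.dtype)
        return dx, dw, dbias

The stand-in gradient values are not real shift-conv gradients; the point of the sketch is only the dtype handling around the kernel calls and the reductions.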