[PYTHON] Fix variable referenced before assignment error in blocksparse matmul backward (#90)

This commit is contained in:
Nora Belrose
2021-04-23 11:22:27 -07:00
committed by Philippe Tillet
parent d9112144b4
commit 1112e2526e

View File

@@ -560,7 +560,9 @@ class _matmul(torch.autograd.Function):
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]