diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 50af3c564..99a4d0281 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -560,7 +560,9 @@ class _matmul(torch.autograd.Function): def backward(ctx, dc): # saved for backward a, b = ctx.saved_tensors + da, db = None, None mode = ctx.mode + # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2]