diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index c1787eaa8..abafed584 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -512,8 +512,8 @@ class _matmul(torch.autograd.Function): b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), - c.stride(2), - c.stride(3), + c.stride(3 if trans_c else 2), + c.stride(2 if trans_c else 3), BS3, AS1, 0,