From 001fb757fe0fea4f82e721378ed2f9dea35513ba Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 6 Jan 2022 12:56:22 -0500 Subject: [PATCH] [OPS][BLOCKSPARSE] Added `.contiguous()` in blocksparse inputs when necessary (#420) --- python/triton/ops/blocksparse/matmul.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 9c3317fe0..ce15c9af4 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -76,6 +76,10 @@ def _sdd_kernel( tl.store(pc, c, mask=True) def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # (A * B)^T = B^T * A^T if trans_c: a, b = b, a @@ -190,6 +194,10 @@ def _dsd_kernel( tl.store(pc, c, mask = offs_cn[None, :] < DS0) def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # shapes / dtypes AS1 = block * spdims[2 if trans_a else 1] BS0 = b.size(0) @@ -378,6 +386,10 @@ def _dds_kernel( tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # shapes / dtypes AS0 = a.size(0) AS1 = a.size(1)