[OPS][BLOCKSPARSE] Added .contiguous() in blocksparse inputs when necessary (#420)

This commit is contained in:
Philippe Tillet
2022-01-06 12:56:22 -05:00
committed by GitHub
parent 0ab9d67bad
commit 001fb757fe

View File

@@ -76,6 +76,10 @@ def _sdd_kernel(
tl.store(pc, c, mask=True) tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# (A * B)^T = B^T * A^T # (A * B)^T = B^T * A^T
if trans_c: if trans_c:
a, b = b, a a, b = b, a
@@ -190,6 +194,10 @@ def _dsd_kernel(
tl.store(pc, c, mask = offs_cn[None, :] < DS0) tl.store(pc, c, mask = offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes # shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1] AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0) BS0 = b.size(0)
@@ -378,6 +386,10 @@ def _dds_kernel(
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes # shapes / dtypes
AS0 = a.size(0) AS0 = a.size(0)
AS1 = a.size(1) AS1 = a.size(1)