[OPS][BLOCKSPARSE] Added .contiguous()
in blocksparse inputs when necessary (#420)
This commit is contained in:
@@ -76,6 +76,10 @@ def _sdd_kernel(
|
||||
tl.store(pc, c, mask=True)
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# (A * B)^T = B^T * A^T
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
@@ -190,6 +194,10 @@ def _dsd_kernel(
|
||||
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# shapes / dtypes
|
||||
AS1 = block * spdims[2 if trans_a else 1]
|
||||
BS0 = b.size(0)
|
||||
@@ -378,6 +386,10 @@ def _dds_kernel(
|
||||
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
|
||||
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# shapes / dtypes
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
|
Reference in New Issue
Block a user