[OPS][BLOCKSPARSE] Added .contiguous()
in blocksparse inputs when necessary (#420)
This commit is contained in:
@@ -76,6 +76,10 @@ def _sdd_kernel(
|
|||||||
tl.store(pc, c, mask=True)
|
tl.store(pc, c, mask=True)
|
||||||
|
|
||||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
|
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
|
||||||
|
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||||
|
a = a.contiguous()
|
||||||
|
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||||
|
b = b.contiguous()
|
||||||
# (A * B)^T = B^T * A^T
|
# (A * B)^T = B^T * A^T
|
||||||
if trans_c:
|
if trans_c:
|
||||||
a, b = b, a
|
a, b = b, a
|
||||||
@@ -190,6 +194,10 @@ def _dsd_kernel(
|
|||||||
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
|
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
|
||||||
|
|
||||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
||||||
|
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||||
|
a = a.contiguous()
|
||||||
|
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||||
|
b = b.contiguous()
|
||||||
# shapes / dtypes
|
# shapes / dtypes
|
||||||
AS1 = block * spdims[2 if trans_a else 1]
|
AS1 = block * spdims[2 if trans_a else 1]
|
||||||
BS0 = b.size(0)
|
BS0 = b.size(0)
|
||||||
@@ -378,6 +386,10 @@ def _dds_kernel(
|
|||||||
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
|
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
|
||||||
|
|
||||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
||||||
|
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||||
|
a = a.contiguous()
|
||||||
|
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||||
|
b = b.contiguous()
|
||||||
# shapes / dtypes
|
# shapes / dtypes
|
||||||
AS0 = a.size(0)
|
AS0 = a.size(0)
|
||||||
AS1 = a.size(1)
|
AS1 = a.size(1)
|
||||||
|
Reference in New Issue
Block a user