[PYTHON] Fix crash in sparse matmul for tensors w/ less than 4 dims (#91)

Committed by Philippe Tillet
parent f6688372db
commit 41410012e8

@@ -1,8 +1,6 @@
import triton
import triton._C.libtriton as libtriton
import torch
-import os
-import math


@triton.jit
@@ -257,7 +255,7 @@ class _matmul(torch.autograd.Function):
    ##########################

    @staticmethod
-    def make_sdd_lut(layout, block, dtype, device):
+    def make_sdd_lut(layout, block, device):
        start_width = 128 // block
        layout = layout.type(torch.int32)
        superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
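
The SDD look-up table is built from the integer layout, the block size, and the target device alone, so the unused dtype parameter is dropped here and at the three make_sdd_lut call sites further down. A minimal sketch of the preprocessing visible in this hunk, with a hypothetical layout mask and block size (libtriton.superblock itself is not called):

import torch

block = 32
layout = torch.randint(0, 2, (2, 8, 8))   # hypothetical (heads, rows//block, cols//block) sparsity mask
start_width = 128 // block                # 4, as computed above
layout = layout.type(torch.int32)         # only the layout matters here; the operands' dtype does not
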
@@ -268,6 +266,7 @@ class _matmul(torch.autograd.Function):
            luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
            widths.append(width)
            packs.append(size)

        # create locks
        return luts, None, widths, packs

@@ -278,18 +277,11 @@ class _matmul(torch.autograd.Function):
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        device = a.device
        is_16_multiple = AS3 % 16 == 0
        is_32_multiple = AS3 % 32 == 0
        is_64_multiple = AS3 % 64 == 0
        if not is_16_multiple:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
        # create kernel
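
The shape bookkeeping above indexes dimensions 2 and 3 of each operand directly (for example AS3 = a.size(2 if trans_a else 3)), which is what failed for inputs with fewer than 4 dimensions before they are padded. A tiny illustration of that failure mode, assuming trans_a is False so AS3 maps to a.size(3):

import torch

a = torch.randn(2, 64, 64)   # only 3 dims: no leading batch dimension
try:
    AS3 = a.size(3)          # what the SDD path computes when trans_a is False
except IndexError as err:
    print(err)               # dimension 3 does not exist on a 3-D tensor
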
@@ -437,9 +429,6 @@ class _matmul(torch.autograd.Function):
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
-        BS0 = spdims[0]
-        BS1 = block * spdims[2 if trans_b else 1]
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
@@ -484,12 +473,9 @@ class _matmul(torch.autograd.Function):
    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel
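
For the DSD and DDS paths the sparse operand's logical shape is reconstructed from (spdims, block) rather than read off the block-compressed tensor, as the two hunks above show. A small worked example with hypothetical numbers; the meaning of the spdims axes is inferred from how make_sdd_lut indexes layout.shape and is not spelled out in this diff:

block = 16
spdims = (2, 4, 8)       # hypothetical layout shape: (heads, rows//block, cols//block)
trans_a = False
AS0 = spdims[0]                             # 2: batch/head count of the sparse operand
AS1 = block * spdims[2 if trans_a else 1]   # 64: logical rows
AS2 = block * spdims[1 if trans_a else 2]   # 128: logical columns
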
@@ -593,7 +579,7 @@ class matmul:
        layout, block = self.layout, self.block
        step = 16
        if self.mode == 'sdd':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
@@ -602,7 +588,7 @@ class matmul:
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
        # DB look-up table
@@ -611,7 +597,7 @@ class matmul:
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
                               da_lut, da_num_locks, da_width, da_packs,\
                               db_lut, db_num_locks, db_width, db_packs)
@@ -630,13 +616,13 @@ class matmul:
        self.block = block
        self.layout = layout

-    # pad shapes of a tensor to make it
-    # compatible with kernel calls
+    # Kernel assumes that all tensors are 4 dimensional
    @staticmethod
-    def _pad_shape(x, is_sparse):
-        max_dim = 3 if is_sparse else 4
-        for i in range(max_dim - x.dim()):
+    def _pad_shape(x):
+        # Add extra batch dimensions if needed
+        for i in range(4 - x.ndim):
            x = x.unsqueeze(0)

        return x

    def __call__(self, a, b):
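
This is the crash fix itself: _pad_shape no longer distinguishes sparse from dense operands (the old version capped sparse inputs at 3 dimensions) and simply prepends singleton dimensions until the tensor is 4-D, which is what the kernels assume. A standalone copy of the new helper to show the effect:

import torch

def _pad_shape(x):
    # Add extra batch dimensions if needed (kernels assume 4-D tensors)
    for i in range(4 - x.ndim):
        x = x.unsqueeze(0)
    return x

print(_pad_shape(torch.randn(64, 64)).shape)     # torch.Size([1, 1, 64, 64])
print(_pad_shape(torch.randn(3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])
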
@@ -644,8 +630,8 @@ class matmul:
        da_lut, da_num_locks, da_width, da_packs,\
        db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # pad shapes with ones
-        a = matmul._pad_shape(a, self.mode == 'dsd')
-        b = matmul._pad_shape(b, self.mode == 'dds')
+        a = matmul._pad_shape(a)
+        b = matmul._pad_shape(b)
        # execute
        c = _matmul.apply(
            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
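
For context, a hypothetical end-to-end call that exercises this path with 3-D operands, which the commit title says used to crash. The triton.ops.blocksparse.matmul import path and the constructor arguments are assumptions inferred from the attributes used in this file (layout, block, mode, trans_a, trans_b), not something this diff shows:

import torch
import triton.ops

block = 16
layout = torch.ones(1, 4, 4, dtype=torch.int64)            # hypothetical: 1 head, fully dense 4x4 block layout
dot = triton.ops.blocksparse.matmul(layout, block, 'sdd')  # assumed signature: (layout, block, mode, ...)

a = torch.randn(1, 64, 64, device='cuda', dtype=torch.float16)  # 3-D operands, now padded to 4-D internally
b = torch.randn(1, 64, 64, device='cuda', dtype=torch.float16)
c = dot(a, b)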