From 41410012e84421ca910b881f7670ab772bb4f166 Mon Sep 17 00:00:00 2001
From: Nora Belrose <39116809+norabelrose@users.noreply.github.com>
Date: Fri, 23 Apr 2021 13:28:44 -0700
Subject: [PATCH] [PYTHON] Fix crash in sparse matmul for tensors w/ less than 4 dims (#91)

---
 python/triton/ops/blocksparse/matmul.py | 38 ++++++++-----------------
 1 file changed, 12 insertions(+), 26 deletions(-)

diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index 99a4d0281..f8e7a0388 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -1,8 +1,6 @@
 import triton
 import triton._C.libtriton as libtriton
 import torch
-import os
-import math
 
 
 @triton.jit
@@ -257,7 +255,7 @@ class _matmul(torch.autograd.Function):
     ##########################
 
     @staticmethod
-    def make_sdd_lut(layout, block, dtype, device):
+    def make_sdd_lut(layout, block, device):
         start_width = 128 // block
         layout = layout.type(torch.int32)
         superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
@@ -268,6 +266,7 @@ class _matmul(torch.autograd.Function):
             luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
             widths.append(width)
             packs.append(size)
+
         # create locks
         return luts, None, widths, packs
 
@@ -278,18 +277,11 @@ class _matmul(torch.autograd.Function):
             a, b = b, a
             trans_a, trans_b = not trans_b, not trans_a
         AS0 = a.size(0)
-        AS1 = a.size(1)
         AS2 = a.size(3 if trans_a else 2)
         AS3 = a.size(2 if trans_a else 3)
-        BS0 = b.size(0)
-        BS1 = b.size(1)
-        BS2 = b.size(3 if trans_b else 2)
-        BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         device = a.device
         is_16_multiple = AS3 % 16 == 0
-        is_32_multiple = AS3 % 32 == 0
-        is_64_multiple = AS3 % 64 == 0
         if not is_16_multiple:
             raise ValueError('Reduction size for SDD must be a multiple of 16')
         # create kernel
@@ -437,9 +429,6 @@ class _matmul(torch.autograd.Function):
         AS0 = a.size(0)
         AS1 = a.size(1)
         AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
-        BS0 = spdims[0]
-        BS1 = block * spdims[2 if trans_b else 1]
         BS2 = block * spdims[1 if trans_b else 2]
         dtype = a.dtype
         # kernel
@@ -484,12 +473,9 @@ class _matmul(torch.autograd.Function):
     @staticmethod
     def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
         # shapes / dtypes
-        AS0 = spdims[0]
         AS1 = block * spdims[2 if trans_a else 1]
-        AS2 = block * spdims[1 if trans_a else 2]
         BS0 = b.size(0)
         BS1 = b.size(1)
-        BS2 = b.size(3 if trans_b else 2)
         BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         # kernel
@@ -593,7 +579,7 @@ class matmul:
         layout, block = self.layout, self.block
         step = 16
         if self.mode == 'sdd':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
         elif self.mode == 'dsd':
             c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
         elif self.mode == 'dds':
@@ -602,7 +588,7 @@ class matmul:
         if self.mode == 'sdd':
             da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
         elif self.mode == 'dsd':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
         elif self.mode == 'dds':
             da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
         # DB look-up table
@@ -611,7 +597,7 @@ class matmul:
         elif self.mode == 'dsd':
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
         elif self.mode == 'dds':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
         self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
                                da_lut, da_num_locks, da_width, da_packs,\
                                db_lut, db_num_locks, db_width, db_packs)
@@ -630,13 +616,13 @@ class matmul:
         self.block = block
         self.layout = layout
 
-    # pad shapes of a tensor to make it
-    # compatible with kernel calls
+    # Kernel assumes that all tensors are 4 dimensional
     @staticmethod
-    def _pad_shape(x, is_sparse):
-        max_dim = 3 if is_sparse else 4
-        for i in range(max_dim - x.dim()):
+    def _pad_shape(x):
+        # Add extra batch dimensions if needed
+        for i in range(4 - x.ndim):
             x = x.unsqueeze(0)
+
         return x
 
     def __call__(self, a, b):
@@ -644,8 +630,8 @@ class matmul:
         da_lut, da_num_locks, da_width, da_packs,\
        db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
         # pad shapes with ones
-        a = matmul._pad_shape(a, self.mode == 'dsd')
-        b = matmul._pad_shape(b, self.mode == 'dds')
+        a = matmul._pad_shape(a)
+        b = matmul._pad_shape(b)
         # execute
         c = _matmul.apply(
             a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
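Note (not part of the patch): a minimal sketch of the simplified padding behavior introduced above, assuming only torch and the _pad_shape logic shown in the last hunks. Both operands are now unsqueezed up to 4 dimensions unconditionally, so a 2-D or 3-D input no longer trips the kernel's 4-D shape assumption.

import torch

def _pad_shape(x):
    # Kernel assumes that all tensors are 4 dimensional:
    # prepend batch dimensions until x.ndim == 4.
    for _ in range(4 - x.ndim):
        x = x.unsqueeze(0)
    return x

a = torch.randn(64, 64)        # 2-D dense operand
print(_pad_shape(a).shape)     # torch.Size([1, 1, 64, 64])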