[PYTHON] Fix crash in sparse matmul for tensors w/ less than 4 dims (#91)

This commit is contained in:
Nora Belrose
2021-04-23 13:28:44 -07:00
committed by Philippe Tillet
parent f6688372db
commit 41410012e8

View File

@@ -1,8 +1,6 @@
import triton
import triton._C.libtriton as libtriton
import torch
import os
import math
@triton.jit
@@ -257,7 +255,7 @@ class _matmul(torch.autograd.Function):
##########################
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
def make_sdd_lut(layout, block, device):
start_width = 128 // block
layout = layout.type(torch.int32)
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
@@ -268,6 +266,7 @@ class _matmul(torch.autograd.Function):
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@@ -278,18 +277,11 @@ class _matmul(torch.autograd.Function):
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
@@ -437,9 +429,6 @@ class _matmul(torch.autograd.Function):
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
@@ -484,12 +473,9 @@ class _matmul(torch.autograd.Function):
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
@@ -593,7 +579,7 @@ class matmul:
layout, block = self.layout, self.block
step = 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
@@ -602,7 +588,7 @@ class matmul:
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
@@ -611,7 +597,7 @@ class matmul:
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
@@ -630,13 +616,13 @@ class matmul:
self.block = block
self.layout = layout
# pad shapes of a tensor to make it
# compatible with kernel calls
# Kernel assumes that all tensors are 4 dimensional
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
def _pad_shape(x):
# Add extra batch dimensions if needed
for i in range(4 - x.ndim):
x = x.unsqueeze(0)
return x
def __call__(self, a, b):
@@ -644,8 +630,8 @@ class matmul:
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
a = matmul._pad_shape(a)
b = matmul._pad_shape(b)
# execute
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,