[PYTHON] Fix crash in sparse matmul for tensors w/ less than 4 dims (#91)
committed by Philippe Tillet
parent f6688372db
commit 41410012e8
@@ -1,8 +1,6 @@
 import triton
 import triton._C.libtriton as libtriton
 import torch
-import os
-import math
 
 
 @triton.jit
@@ -257,7 +255,7 @@ class _matmul(torch.autograd.Function):
     ##########################
 
     @staticmethod
-    def make_sdd_lut(layout, block, dtype, device):
+    def make_sdd_lut(layout, block, device):
         start_width = 128 // block
         layout = layout.type(torch.int32)
         superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
@@ -268,6 +266,7 @@ class _matmul(torch.autograd.Function):
             luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
             widths.append(width)
             packs.append(size)
+
         # create locks
         return luts, None, widths, packs
 
@@ -278,18 +277,11 @@ class _matmul(torch.autograd.Function):
             a, b = b, a
             trans_a, trans_b = not trans_b, not trans_a
         AS0 = a.size(0)
-        AS1 = a.size(1)
         AS2 = a.size(3 if trans_a else 2)
         AS3 = a.size(2 if trans_a else 3)
-        BS0 = b.size(0)
-        BS1 = b.size(1)
-        BS2 = b.size(3 if trans_b else 2)
-        BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         device = a.device
         is_16_multiple = AS3 % 16 == 0
-        is_32_multiple = AS3 % 32 == 0
-        is_64_multiple = AS3 % 64 == 0
         if not is_16_multiple:
             raise ValueError('Reduction size for SDD must be a multiple of 16')
         # create kernel
@@ -437,9 +429,6 @@ class _matmul(torch.autograd.Function):
         AS0 = a.size(0)
         AS1 = a.size(1)
         AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
-        BS0 = spdims[0]
-        BS1 = block * spdims[2 if trans_b else 1]
         BS2 = block * spdims[1 if trans_b else 2]
         dtype = a.dtype
         # kernel
@@ -484,12 +473,9 @@ class _matmul(torch.autograd.Function):
     @staticmethod
     def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
         # shapes / dtypes
-        AS0 = spdims[0]
         AS1 = block * spdims[2 if trans_a else 1]
-        AS2 = block * spdims[1 if trans_a else 2]
         BS0 = b.size(0)
         BS1 = b.size(1)
-        BS2 = b.size(3 if trans_b else 2)
         BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         # kernel
@@ -593,7 +579,7 @@ class matmul:
         layout, block = self.layout, self.block
         step = 16
         if self.mode == 'sdd':
-            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)
         elif self.mode == 'dsd':
             c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
         elif self.mode == 'dds':
@@ -602,7 +588,7 @@ class matmul:
         if self.mode == 'sdd':
             da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
         elif self.mode == 'dsd':
-            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)
         elif self.mode == 'dds':
             da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
         # DB look-up table
@@ -611,7 +597,7 @@ class matmul:
         elif self.mode == 'dsd':
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
         elif self.mode == 'dds':
-            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
+            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
         self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
                                da_lut, da_num_locks, da_width, da_packs,\
                                db_lut, db_num_locks, db_width, db_packs)
@@ -630,13 +616,13 @@ class matmul:
         self.block = block
         self.layout = layout
 
-    # pad shapes of a tensor to make it
-    # compatible with kernel calls
+    # Kernel assumes that all tensors are 4 dimensional
     @staticmethod
-    def _pad_shape(x, is_sparse):
-        max_dim = 3 if is_sparse else 4
-        for i in range(max_dim - x.dim()):
+    def _pad_shape(x):
+        # Add extra batch dimensions if needed
+        for i in range(4 - x.ndim):
             x = x.unsqueeze(0)
+
         return x
 
     def __call__(self, a, b):
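For reference, a standalone before/after sketch of the padding helper changed above (both bodies are copied from this hunk; the input shape is hypothetical and purely illustrative). The old helper stopped at 3 dimensions when is_sparse was set, so an operand could still reach the kernels with fewer than the 4 dimensions they assume; the new helper always promotes its input to 4-D.

import torch

# Before this commit: sparse operands were only padded up to 3 dims.
def _pad_shape_old(x, is_sparse):
    max_dim = 3 if is_sparse else 4
    for i in range(max_dim - x.dim()):
        x = x.unsqueeze(0)
    return x

# After this commit: every operand is padded up to 4 dims.
def _pad_shape_new(x):
    # Add extra batch dimensions if needed
    for i in range(4 - x.ndim):
        x = x.unsqueeze(0)
    return x

x = torch.randn(16, 32, 32)                     # hypothetical 3-D operand
print(_pad_shape_old(x, is_sparse=True).shape)  # torch.Size([16, 32, 32])    -- still 3-D
print(_pad_shape_new(x).shape)                  # torch.Size([1, 16, 32, 32]) -- now 4-D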
@@ -644,8 +630,8 @@ class matmul:
         da_lut, da_num_locks, da_width, da_packs,\
         db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
         # pad shapes with ones
-        a = matmul._pad_shape(a, self.mode == 'dsd')
-        b = matmul._pad_shape(b, self.mode == 'dds')
+        a = matmul._pad_shape(a)
+        b = matmul._pad_shape(b)
         # execute
         c = _matmul.apply(
             a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
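And a hedged usage sketch of the wrapper whose __call__ is patched above. The module path, constructor signature, and layout convention are assumptions based on the surrounding code of this era, not something this diff shows; the point is only that after this change both operands go through _pad_shape unconditionally, so inputs with fewer than 4 dimensions are promoted to 4-D before _matmul.apply runs.

import torch
from triton.ops.blocksparse import matmul  # assumed module path for the class patched in this file

block = 16
H, M, N, K = 2, 64, 64, 64
# Assumed layout convention: one {0, 1} entry per (head, block-row, block-col) of the sparse output.
layout = torch.ones(H, M // block, N // block, dtype=torch.int64)

# Assumed constructor signature: matmul(layout, block, mode, trans_a=False, trans_b=False).
dot = matmul(layout, block, 'sdd', trans_a=False, trans_b=False)

# 3-D fp16 inputs (no leading batch dimension); _pad_shape promotes them to 4-D internally.
a = torch.randn(H, M, K, dtype=torch.float16, device='cuda')
b = torch.randn(H, K, N, dtype=torch.float16, device='cuda')
c = dot(a, b)  # compact block-sparse result; its exact shape depends on the layout
print(c.shape)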