[PYTHON] Added shape & device checks for inputs to sparse matmul op (#93)
This commit is contained in:
committed by
Philippe Tillet
parent
bfc0a7587d
commit
d7f87929fa
@@ -197,7 +197,7 @@ class _matmul(torch.autograd.Function):
|
||||
# performs load-balancing to achieve more smaller reductions
|
||||
# between `seg_size` elements
|
||||
@staticmethod
|
||||
def load_balance(sizes, block):
|
||||
def load_balance(sizes):
|
||||
# segment size
|
||||
# heuristics taken from OpenAI blocksparse code
|
||||
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
|
||||
@@ -273,33 +273,41 @@ class _matmul(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
|
||||
|
||||
# (A * B)^T = (B^T * A^T)
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
trans_a, trans_b = not trans_b, not trans_a
|
||||
AS0 = a.size(0)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
|
||||
# Shape check
|
||||
a_dim = -2 if trans_a else -1
|
||||
b_dim = -1 if trans_b else -2
|
||||
a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
|
||||
if a_inner != b_inner:
|
||||
raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
|
||||
f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
|
||||
if a_inner % 16 != 0:
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
|
||||
batch_size = a.size(0)
|
||||
a_outer = a.size(3 if trans_a else 2)
|
||||
dtype = a.dtype
|
||||
device = a.device
|
||||
is_16_multiple = AS3 % 16 == 0
|
||||
if not is_16_multiple:
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
|
||||
# create kernel
|
||||
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
|
||||
c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
|
||||
c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
|
||||
for lut, width, pack in zip(luts, widths, packs):
|
||||
num_lock = 1
|
||||
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
|
||||
meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
|
||||
'SDD': True, 'DSD': False, 'DDS': False}
|
||||
# create output
|
||||
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
|
||||
locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
|
||||
# maximum grid size is 65535
|
||||
# so operation might be decomposed into multiple
|
||||
# kernel calls
|
||||
max_width = 49152
|
||||
for off_width in range(0, width, max_width):
|
||||
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
|
||||
grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
|
||||
_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
@@ -316,9 +324,9 @@ class _matmul(torch.autograd.Function):
|
||||
c.stride(0),
|
||||
c.stride(2),
|
||||
c.stride(3),
|
||||
AS2,
|
||||
AS2,
|
||||
AS3,
|
||||
a_outer,
|
||||
a_outer,
|
||||
a_inner,
|
||||
off_width,
|
||||
lut,
|
||||
locks,
|
||||
@@ -353,7 +361,7 @@ class _matmul(torch.autograd.Function):
|
||||
sizes = torch.sum(layout[z, :, :], 1)
|
||||
else:
|
||||
sizes = torch.sum(layout[z, :, :], 0)
|
||||
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
|
||||
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
|
||||
z_depth = z * torch.ones_like(z_segments)
|
||||
z_lockid[z_lockid > 0] += current_maxid
|
||||
current_maxid = z_lockid.max()
|
||||
@@ -433,7 +441,7 @@ class _matmul(torch.autograd.Function):
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
|
||||
meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
|
||||
'SDD': False, 'DSD': False, 'DDS': True}
|
||||
# output
|
||||
CS0 = AS0
|
||||
@@ -480,7 +488,7 @@ class _matmul(torch.autograd.Function):
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
|
||||
meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
|
||||
'SDD': False, 'DSD': True, 'DDS': False}
|
||||
# output
|
||||
CS0 = BS0
|
||||
@@ -599,8 +607,8 @@ class matmul:
|
||||
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
|
||||
elif self.mode == 'dds':
|
||||
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
|
||||
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
|
||||
da_lut, da_num_locks, da_width, da_packs,\
|
||||
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
|
||||
da_lut, da_num_locks, da_width, da_packs,
|
||||
db_lut, db_num_locks, db_width, db_packs)
|
||||
return self.lut_cache[key]
|
||||
|
||||
@@ -610,32 +618,98 @@ class matmul:
|
||||
# look-up table cache
|
||||
self.lut_cache = dict()
|
||||
# attributes
|
||||
self.block = block
|
||||
self.mode = mode
|
||||
self.trans_a = trans_a
|
||||
self.trans_b = trans_b
|
||||
self.mode = mode
|
||||
self.spdims = layout.shape
|
||||
self.block = block
|
||||
|
||||
layout_dim = layout.ndim
|
||||
assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
|
||||
|
||||
if not mode == 'sdd':
|
||||
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
|
||||
trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
|
||||
self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
|
||||
sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
|
||||
|
||||
# Inner dim of the dense input should be equal to the inner dim of the sparse input
|
||||
self.dense_inner_size = layout.shape[sparse_inner] * block
|
||||
# Expected shape for sparse inputs
|
||||
self.sparse_shape = (layout.sum().item(), block, block)
|
||||
|
||||
# Support using the same layout across attention heads etc.
|
||||
if layout_dim == 2:
|
||||
layout = layout.unsqueeze(0)
|
||||
|
||||
layout = layout.long() # Above code assumes the layout tensor is an integral type
|
||||
self.layout = layout
|
||||
|
||||
# Kernel assumes that all tensors are 4 dimensional
|
||||
@staticmethod
|
||||
def _pad_shape(x):
|
||||
# Add extra batch dimensions if needed
|
||||
for i in range(4 - x.ndim):
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
return x
|
||||
self.spdims = layout.shape
|
||||
|
||||
def __call__(self, a, b):
|
||||
c_lut, c_num_locks, c_width, c_packs,\
|
||||
da_lut, da_num_locks, da_width, da_packs,\
|
||||
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
|
||||
# pad shapes with ones
|
||||
a = matmul._pad_shape(a)
|
||||
b = matmul._pad_shape(b)
|
||||
|
||||
# If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
|
||||
# and potential illegal memory accesses
|
||||
original_dims = max(a.ndim, b.ndim)
|
||||
a, b = self._validate_inputs(a, b)
|
||||
|
||||
# execute
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
|
||||
da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
|
||||
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
|
||||
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
|
||||
)
|
||||
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
|
||||
dims_to_trim = c.ndim - original_dims
|
||||
for _ in range(dims_to_trim):
|
||||
c = c.squeeze(0)
|
||||
|
||||
return c
|
||||
|
||||
def _validate_inputs(self, a, b):
|
||||
if a.device != b.device:
|
||||
raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
|
||||
f"and {b.device} for tensor B")
|
||||
if not a.is_cuda:
|
||||
raise ValueError("Only GPU devices are supported for now")
|
||||
|
||||
# When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
|
||||
if torch.is_autocast_enabled():
|
||||
a, b = a.half(), b.half()
|
||||
elif a.dtype != b.dtype:
|
||||
raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
|
||||
|
||||
mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
|
||||
if mode != 'sdd':
|
||||
# One input is sparse
|
||||
dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
|
||||
dense_inner = dense.shape[self.dense_inner_dim]
|
||||
if dense_inner != self.dense_inner_size:
|
||||
raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
|
||||
f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
|
||||
|
||||
if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
|
||||
raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
|
||||
f"{sparse_name}, got {sparse.shape}")
|
||||
|
||||
def add_extra_dims(x):
|
||||
# Add extra leading singleton dimensions if needed
|
||||
dims_needed = 4 - x.ndim
|
||||
if dims_needed > 0:
|
||||
singletons = [1] * dims_needed
|
||||
x = x.view(*singletons, *x.shape)
|
||||
elif dims_needed < 0:
|
||||
raise ValueError("Tensors with more than 4 dimensions are not currently supported")
|
||||
|
||||
return x
|
||||
|
||||
# Pad shapes with leading singleton dimensions
|
||||
a = add_extra_dims(a)
|
||||
b = add_extra_dims(b)
|
||||
|
||||
return a, b
|
||||
|
||||
def _dim_to_name(x):
|
||||
# assert x in (-1, -2)
|
||||
return "last" if x == -1 else "second to last"
|
||||
|
Reference in New Issue
Block a user