[PYTHON] Added shape & device checks for inputs to sparse matmul op (#93)
commit d7f87929fa (parent bfc0a7587d)
committed by Philippe Tillet
@@ -197,7 +197,7 @@ class _matmul(torch.autograd.Function):
     # performs load-balancing to achieve more smaller reductions
     # between `seg_size` elements
     @staticmethod
-    def load_balance(sizes, block):
+    def load_balance(sizes):
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
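The comments above describe what `load_balance` is for: splitting each row's reduction into segments of at most `seg_size` elements, so one large reduction becomes several smaller, more uniform ones. A minimal self-contained sketch of that idea, not Triton's actual implementation (`split_into_segments` and the example sizes are illustrative):

    import torch

    def split_into_segments(sizes, seg_size):
        # sizes[i] = number of nonzero blocks contributing to output row i
        segments = []  # (row, segment_length) pairs; shorter, more uniform reductions
        for row, size in enumerate(sizes.tolist()):
            while size > 0:
                segments.append((row, min(size, seg_size)))
                size -= seg_size
        return segments

    print(split_into_segments(torch.tensor([5, 1, 3]), seg_size=2))
    # [(0, 2), (0, 2), (0, 1), (1, 1), (2, 2), (2, 1)]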
@@ -273,33 +273,41 @@ class _matmul(torch.autograd.Function):
 
     @staticmethod
     def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
+        # (A * B)^T = (B^T * A^T)
         if trans_c:
             a, b = b, a
             trans_a, trans_b = not trans_b, not trans_a
-        AS0 = a.size(0)
-        AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
+        # Shape check
+        a_dim = -2 if trans_a else -1
+        b_dim = -1 if trans_b else -2
+        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
+        if a_inner != b_inner:
+            raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
+                             f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
+        if a_inner % 16 != 0:
+            raise ValueError('Reduction size for SDD must be a multiple of 16')
+
+        batch_size = a.size(0)
+        a_outer = a.size(3 if trans_a else 2)
         dtype = a.dtype
         device = a.device
-        is_16_multiple = AS3 % 16 == 0
-        if not is_16_multiple:
-            raise ValueError('Reduction size for SDD must be a multiple of 16')
         # create kernel
         total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
-        c = torch.zeros((AS0, total_width, block, block), dtype=dtype, device=device)
+        c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)
         for lut, width, pack in zip(luts, widths, packs):
             num_lock = 1
-            meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1, \
+            meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': 32, 'TZ': 1,
                     'SDD': True, 'DSD': False, 'DDS': False}
             # create output
-            locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
+            locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)
             # maximum grid size is 65535
             # so operation might be decomposed into multiple
             # kernel calls
             max_width = 49152
             for off_width in range(0, width, max_width):
-                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), AS0]
+                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
                 _kernel[grid](
                     a,
                     b,
@@ -316,9 +324,9 @@ class _matmul(torch.autograd.Function):
                     c.stride(0),
                     c.stride(2),
                     c.stride(3),
-                    AS2,
-                    AS2,
-                    AS3,
+                    a_outer,
+                    a_outer,
+                    a_inner,
                     off_width,
                     lut,
                     locks,
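Replacing the positional `AS*` sizes with the named `a_inner`/`a_outer` makes the validation self-describing. The check itself can be reproduced standalone to show which dims are compared under each transpose flag; `check_inner_dims` below is an illustrative wrapper, not part of the library, while `_dim_to_name` is copied from this commit:

    import torch

    def _dim_to_name(x):
        return "last" if x == -1 else "second to last"

    def check_inner_dims(a, b, trans_a=False, trans_b=False):
        a_dim = -2 if trans_a else -1  # A's reduction dim
        b_dim = -1 if trans_b else -2  # B's reduction dim
        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
        if a_inner != b_inner:
            raise ValueError(f"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size "
                             f"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})")
        if a_inner % 16 != 0:
            raise ValueError('Reduction size for SDD must be a multiple of 16')

    a = torch.empty(1, 1, 64, 32)
    b = torch.empty(1, 1, 32, 48)
    check_inner_dims(a, b)  # OK: A's last dim (32) matches B's second-to-last dim (32)
    try:
        check_inner_dims(a, b, trans_b=True)  # compares A's last dim (32) to B's last dim (48)
    except ValueError as e:
        print(e)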
@@ -353,7 +361,7 @@ class _matmul(torch.autograd.Function):
                 sizes = torch.sum(layout[z, :, :], 1)
             else:
                 sizes = torch.sum(layout[z, :, :], 0)
-            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
+            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes)
             z_depth = z * torch.ones_like(z_segments)
             z_lockid[z_lockid > 0] += current_maxid
             current_maxid = z_lockid.max()
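The two lines after the `load_balance` call implement an easy-to-miss pattern: each layout slice `z` gets lock ids starting at 1 (0 means "no lock needed"), and shifting the nonzero ids by the running maximum keeps them unique across slices. A toy reproduction with made-up per-slice ids:

    import torch

    current_maxid = 0
    per_slice_lockids = [torch.tensor([0, 1, 1, 2]), torch.tensor([1, 0, 2, 2])]
    renumbered = []
    for z_lockid in per_slice_lockids:
        z_lockid = z_lockid.clone()
        z_lockid[z_lockid > 0] += current_maxid  # shift this slice's ids past all earlier ones
        current_maxid = z_lockid.max()
        renumbered.append(z_lockid)
    print(renumbered)  # [tensor([0, 1, 1, 2]), tensor([3, 0, 4, 4])]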
@@ -433,7 +441,7 @@ class _matmul(torch.autograd.Function):
         BS2 = block * spdims[1 if trans_b else 2]
         dtype = a.dtype
         # kernel
-        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
+        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                 'SDD': False, 'DSD': False, 'DDS': True}
         # output
         CS0 = AS0
@@ -480,7 +488,7 @@ class _matmul(torch.autograd.Function):
         BS3 = b.size(2 if trans_b else 3)
         dtype = a.dtype
         # kernel
-        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\
+        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,
                 'SDD': False, 'DSD': True, 'DDS': False}
         # output
         CS0 = BS0
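For context on the SDD/DSD/DDS flags these meta dicts toggle: 'sdd' computes a sparse output from two dense inputs, 'dsd' takes a sparse A, and 'dds' a sparse B. A hedged usage sketch follows; the import path and constructor signature are assumed from the Triton code of this era (and it needs a CUDA device with Triton installed), so treat them as assumptions rather than a documented API:

    import torch
    from triton.ops.blocksparse import matmul

    block = 16
    layout = torch.ones(2, 4, 4, dtype=torch.long)  # heads x (rows/block) x (cols/block)

    dot_sdd = matmul(layout, block, 'sdd', trans_a=False, trans_b=False)
    a = torch.randn(1, 2, 64, 64, device='cuda', dtype=torch.float16)
    b = torch.randn(1, 2, 64, 64, device='cuda', dtype=torch.float16)
    c = dot_sdd(a, b)  # shape (1, layout.sum(), block, block): only nonzero blocks materialized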
@@ -599,8 +607,8 @@ class matmul:
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
         elif self.mode == 'dds':
             db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)
-        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
-                               da_lut, da_num_locks, da_width, da_packs,\
+        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
+                               da_lut, da_num_locks, da_width, da_packs,
                                db_lut, db_num_locks, db_width, db_packs)
         return self.lut_cache[key]
 
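`make_lut` memoizes its results in `self.lut_cache`, so the look-up tables are built once per configuration and reused on every `__call__` (which invokes `self.make_lut(a.dtype, a.device)`). The shape of that pattern, reduced to a sketch with illustrative names:

    class LutCache:
        def __init__(self):
            self.lut_cache = dict()

        def make_lut(self, dtype, device):
            key = (dtype, device)
            if key in self.lut_cache:
                return self.lut_cache[key]                        # cheap cache hit on later calls
            self.lut_cache[key] = self._build(dtype, device)      # expensive, done once per key
            return self.lut_cache[key]

        def _build(self, dtype, device):
            return ("c_lut", "da_lut", "db_lut")                  # placeholder for real LUT construction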
@@ -610,32 +618,98 @@ class matmul:
         # look-up table cache
         self.lut_cache = dict()
         # attributes
+        self.block = block
+        self.mode = mode
         self.trans_a = trans_a
         self.trans_b = trans_b
-        self.mode = mode
-        self.spdims = layout.shape
-        self.block = block
+        layout_dim = layout.ndim
+        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
+
+        if not mode == 'sdd':
+            # Dims to be reduced on the 'inside' of the matmul, either -1 or -2
+            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
+            self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
+            sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
+
+            # Inner dim of the dense input should be equal to the inner dim of the sparse input
+            self.dense_inner_size = layout.shape[sparse_inner] * block
+            # Expected shape for sparse inputs
+            self.sparse_shape = (layout.sum().item(), block, block)
+
+        # Support using the same layout across attention heads etc.
+        if layout_dim == 2:
+            layout = layout.unsqueeze(0)
+
+        layout = layout.long()  # Above code assumes the layout tensor is an integral type
         self.layout = layout
+        self.spdims = layout.shape
 
-    # Kernel assumes that all tensors are 4 dimensional
-    @staticmethod
-    def _pad_shape(x):
-        # Add extra batch dimensions if needed
-        for i in range(4 - x.ndim):
-            x = x.unsqueeze(0)
-        return x
-
     def __call__(self, a, b):
         c_lut, c_num_locks, c_width, c_packs,\
         da_lut, da_num_locks, da_width, da_packs,\
         db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
-        # pad shapes with ones
-        a = matmul._pad_shape(a)
-        b = matmul._pad_shape(b)
+        # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
+        # and potential illegal memory accesses
+        original_dims = max(a.ndim, b.ndim)
+        a, b = self._validate_inputs(a, b)
+
         # execute
         c = _matmul.apply(
-            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, c_packs,
-            da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
+            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
+            c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
         )
+        # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
+        dims_to_trim = c.ndim - original_dims
+        for _ in range(dims_to_trim):
+            c = c.squeeze(0)
+
         return c
+
+    def _validate_inputs(self, a, b):
+        if a.device != b.device:
+            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
+                             f"and {b.device} for tensor B")
+        if not a.is_cuda:
+            raise ValueError("Only GPU devices are supported for now")
+
+        # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
+        if torch.is_autocast_enabled():
+            a, b = a.half(), b.half()
+        elif a.dtype != b.dtype:
+            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
+
+        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
+        if mode != 'sdd':
+            # One input is sparse
+            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
+            dense_inner = dense.shape[self.dense_inner_dim]
+            if dense_inner != self.dense_inner_size:
+                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
+                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
+
+            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
+                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
+                                 f"{sparse_name}, got {sparse.shape}")
+
+        def add_extra_dims(x):
+            # Add extra leading singleton dimensions if needed
+            dims_needed = 4 - x.ndim
+            if dims_needed > 0:
+                singletons = [1] * dims_needed
+                x = x.view(*singletons, *x.shape)
+            elif dims_needed < 0:
+                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
+
+            return x
+
+        # Pad shapes with leading singleton dimensions
+        a = add_extra_dims(a)
+        b = add_extra_dims(b)
+
+        return a, b
+
+
+def _dim_to_name(x):
+    # assert x in (-1, -2)
+    return "last" if x == -1 else "second to last"
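One behavioral note on `_validate_inputs`: under autocast, the dtype-mismatch check is deliberately skipped and both inputs are cast to half, matching what torch.matmul does. A small sketch of the observable effect (assumes a CUDA device; only the casting logic is shown, not the actual op):

    import torch

    a = torch.randn(64, 64, device='cuda', dtype=torch.float32)
    b = torch.randn(64, 64, device='cuda', dtype=torch.float16)

    with torch.cuda.amp.autocast():
        # torch.is_autocast_enabled() is True here, so _validate_inputs casts
        # both inputs with .half() instead of raising the dtype-mismatch error
        assert torch.is_autocast_enabled()
        a16, b16 = a.half(), b.half()
        assert a16.dtype == b16.dtype == torch.float16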