Files
triton/python/triton/ops/blocksparse/matmul.py
Philippe Tillet 79d098450f [PYTHON][TESTS][DOC] Various improvement of the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency:
  > .data_ptr() must now be passed as kernel argument. No more implicit
conversion from torch.tensor
  > compilation options are now constant attributes, i.e., opt.d('VAR')
becomes opt.VAR
  > torch.device must now be passed explicitly to triton.kernel (no
longer inferred from torch.tensor arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in python/tutorials/
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* added python/triton/ops/ package for pre-written Triton ops
2021-01-29 17:27:16 -05:00

467 lines
18 KiB
Python

import triton
import triton._C.libtriton as libtriton
import torch
import os
import math
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
sdd_cache = dict()
dsd_cache = dict()
dds_cache = dict()
locks = dict()
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current+d] = seg_max
if r < seg_min and not isempty:
segments[current+d-1] += r
if r >= seg_min or isempty:
segments[current+d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
##########################
# SPARSE = DENSE x DENSE #
##########################
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 64 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size*size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
b = nnz[:, 3]
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
luts.append(lut.type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
F32TK = [8, 16]
#F16TK = [16]
#F16TK += [32] if is_32_multiple else []
#F16TK += [64] if is_64_multiple else []
F16TK = [64]
TK = {torch.float32: F32TK,
torch.float16: F16TK}[dtype]
defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
'TK': TK, 'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
idx = transform(nnz[:, 2]*block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0 ] -= (div-1)*step
# first increment for each reduction is actually the offset
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx*block*block
wincs[1:] -= widx[:-1]*block*block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div-1)*step
else:
wincs[:, 1:] = step*block
wincs[:, 0] -= (div - 1)*step*block
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2*div
segments *= div
# create header
width = column.size(0)
offsets += 6*width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': TM, 'TN': block, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': block, 'TN': TN, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = 8 if dtype == torch.float32 else 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a = False, trans_b = False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout
# pad shapes of a tensor to make it
# compatible with kernel calls
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
x = x.unsqueeze(0)
return x
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return c