Files
triton/python/triton/ops/blocksparse/softmax.py
Jared Kaplan 045ab5d62a [PYTHON] Add Blocksparse Attention Fwd/Bwd Test (#69)
Also includes small bugfix for block-sparse softmax
2021-07-27 12:38:49 -07:00

193 lines
7.5 KiB
Python

import triton
import torch
import os
# The forward and backward kernels are both read from the 'softmax.c' file that
# sits next to this module; each read requests one kernel by name.
_softmax_c_path = os.path.join(os.path.dirname(__file__), 'softmax.c')
fwd_src = triton.read(_softmax_c_path, kernel_names=['forward'])
bwd_src = triton.read(_softmax_c_path, kernel_names=['backward'])
# Caches handed to _softmax.make_kernel, which stores compiled kernels in them
# keyed by their compile-time configuration.
fwd_kernels = {}
bwd_kernels = {}
class _softmax(torch.autograd.Function):
@staticmethod
def next_power_of_2(n):
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
idx = torch.arange(layout.sum())
head = layout.nonzero(as_tuple=False)[:, 0]
rows = layout.nonzero(as_tuple=False)[:, 1]
columns = layout.nonzero(as_tuple=False)[:, 2]
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
# construct look-up table
offsets = offsets * 4 + 2 * sizes.numel()
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, core)).type(torch.int32).to(device)
return lut, int(sizes.max())
@staticmethod
def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode):
if max_k >= 32768:
raise NotImplementedError('Reductions larger than 32768 elements '\
'are not yet implemented')
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
TN = _softmax.next_power_of_2(max_k)
# just-in-time compile kernel
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
kp_mask_mode, attn_mask_mode)
if key not in cache:
defines = {
'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
{torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
}
if apply_scale:
defines['APPLY_SCALE'] = True
if apply_rpe:
defines['APPLY_RPE'] = True
if apply_kp_mask:
defines['APPLY_KP_MASK'] = True
if kp_mask_mode == 'mul':
defines['KP_MASK_MUL'] = True
if apply_attn_mask:
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
cache[key] = kernel
return cache[key]
@staticmethod
def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
maxlut, bench, time):
apply_scale = False if scale == 1.0 else True
# handle None rpe
if rpe is None:
apply_rpe = False
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_rpe = True
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
# handle None key_padding_mask
if key_padding_mask is None:
apply_kp_mask = False
stride_zkpm = 0
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_kp_mask = True
stride_zkpm = key_padding_mask.stride(0)
# handle None attention_mask
if attn_mask is None:
apply_attn_mask = False
stride_zattnm = 0
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_attn_mask = True
stride_zattnm = attn_mask.stride(0)
# run kernel
kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [spdims[0] * spdims[1] * block, M]
# run kernel
kernel(x.data_ptr(),
scale,
lut.data_ptr(),
rpe.data_ptr(),
key_padding_mask.data_ptr(),
attn_mask.data_ptr(),
maxlut,
x.stride(0),
stride_zrpe,
stride_hrpe,
stride_srpe,
stride_zkpm,
stride_zattnm,
grid=grid)
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.apply_scale = apply_scale
ctx.apply_rpe = apply_rpe
ctx.apply_kp_mask = apply_kp_mask
ctx.apply_attn_mask = apply_attn_mask
ctx.kp_mask_mode = kp_mask_mode
ctx.attn_mask_mode = attn_mask_mode
return x
@staticmethod
def backward(ctx, dx):
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel
kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
ctx.kp_mask_mode, ctx.attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class softmax:
    """Callable front-end for the block-sparse softmax of a fixed layout.

    An instance stores the layout and block size, builds the look-up table
    once per device on demand, and forwards calls to ``_softmax.apply``.
    """

    apply_softmax = _softmax.apply

    def __init__(self, layout, block, bench=False):
        """Store the layout, its shape, the block size and the bench flag."""
        self.layout = layout
        self.spdims = layout.shape
        self.block = block
        self.bench = bench
        # maps (device, ) -> (lut, maxlut) as returned by _softmax.make_lut
        self.lut_cache = {}

    def make_lut(self, device):
        """Return the ``(lut, maxlut)`` pair for ``device``, building it on first use."""
        cache_key = (device, )
        entry = self.lut_cache.get(cache_key)
        if entry is None:
            entry = _softmax.make_lut(self.layout, self.block, device)
            self.lut_cache[cache_key] = entry
        return entry

    def __call__(self,
                 x,
                 scale=1.,
                 rpe=None,
                 key_padding_mask=None,
                 attn_mask=None,
                 key_padding_mask_mode='add',
                 attn_mask_mode='add'):
        """Apply the softmax to ``x`` and return the result.

        Raises:
            ValueError: if ``rpe``, ``attn_mask`` or ``key_padding_mask`` is
                given with a dtype different from ``x.dtype``.
        """
        # optional inputs must share the dtype of x; checked in this order
        dtype_checks = (
            (rpe, 'relative position embedding must be %s'),
            (attn_mask, 'Attention mask must be %s'),
            (key_padding_mask, 'Key padding mask must be %s'),
        )
        for tensor, message in dtype_checks:
            if tensor is not None and tensor.dtype != x.dtype:
                raise ValueError(message % x.dtype)
        time_y = [None]
        lut, maxlut = self.make_lut(x.device)
        return softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode,
                                     attn_mask_mode, self.spdims, self.block, lut, maxlut, self.bench, time_y)