[PYTHON][OPS] Added block-sparse softmax
@@ -23,7 +23,7 @@ def mask_tensor(x, mask, block, value = 0):
                            for block in [16, 32, 64]
                          ]
                          )
-def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
+def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
     # set seed
     torch.random.manual_seed(0)
     # create inputs
@@ -48,3 +48,42 @@ def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2,
     rtol, atol = {torch.float32: (1e-4, 1e-5),
                   torch.float16: (1e-2, 1e-3)}[DTYPE]
     assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("BLOCK, WIDTH",
+                         [
+                           (block, width) for block in [16]\
+                                          for width in [256, 576]
+                         ]
+                         )
+def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
+    # set seed
+    torch.random.manual_seed(0)
+    Z, H, M, N = 2, 4, WIDTH, WIDTH
+    scale = 0.4
+    # create inputs
+    layout = torch.randint(2, (H, M//BLOCK, N//BLOCK))
+    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
+    at_mask = torch.randint(low=0, high=2, size=(N, N), \
+                            dtype=torch.bool, requires_grad=False, device='cuda')
+    kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
+                            dtype=DTYPE, requires_grad=False, device='cuda')
+    kp_mask[kp_mask==1.] = float('-inf')
+    # triton result
+    op = tt.ops.blocksparse.softmax(layout, BLOCK)
+    tx = sparsify_tensor(x, layout, BLOCK)
+    ty = op(tx, scale=scale, key_padding_mask=kp_mask, key_padding_mask_mode='add', attn_mask=at_mask.to(DTYPE), attn_mask_mode='mul')
+    # torch result
+    rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
+    if at_mask is not None:
+        # broadcast at_mask to the same shape as rx
+        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
+        rx[M == 0] = float('-inf')
+    if kp_mask is not None:
+        rx += kp_mask[:, None, None, :]
+    ry = torch.softmax(rx*scale, -1)
+    ry = sparsify_tensor(ry, layout, BLOCK)
+    # compare
+    rtol, atol = {torch.float32: (1e-4, 1e-5),
+                  torch.float16: (1e-2, 1e-3)}[DTYPE]
+    assert torch.allclose(ry, ty, rtol=rtol, atol=atol)
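The test above relies on the sparsify_tensor and mask_tensor helpers defined near the top of the test file, outside this hunk. A minimal sketch of the behavior they are assumed to have (illustration only, the real helpers may differ in detail):

import torch

def sparsify_tensor(x, layout, block):
    # gather the blocks selected by `layout` (H, M//block, N//block) into a
    # packed tensor of shape (Z, nnz_blocks, block, block)
    blocks = [x[:, h, i*block:(i+1)*block, j*block:(j+1)*block]
              for h, i, j in layout.nonzero(as_tuple=False).tolist()]
    return torch.stack(blocks, dim=1)

def mask_tensor(x, layout, block, value=0):
    # overwrite the blocks that `layout` marks as zero with `value`
    ret = x.clone()
    for h, i, j in (layout == 0).nonzero(as_tuple=False).tolist():
        ret[:, h, i*block:(i+1)*block, j*block:(j+1)*block] = value
    return ret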
@@ -1 +1,2 @@
 from .matmul import matmul
+from .softmax import softmax
python/triton/ops/blocksparse/softmax.c (new file, 154 lines)
@@ -0,0 +1,154 @@
__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
                        float scale,
                        int *LUT __readonly __noalias __aligned(16),
                        TYPE *RPE __readonly __noalias __aligned(16),
                        TYPE *KP_M __readonly __noalias __aligned(16),
                        TYPE *ATTN_M __readonly __noalias __aligned(16),
                        int sizemax,
                        long stride_zx __multipleof(BLOCK),
                        long stride_zrpe __multipleof(BLOCK),
                        int stride_hrpe __multipleof(BLOCK),
                        int stride_srpe __multipleof(BLOCK),
                        int stride_zkpm __multipleof(BLOCK),
                        int stride_zattnm __multipleof(BLOCK)){
  int pidhm = get_program_id(0);
  int pidz = get_program_id(1);
  // create index ranges
  int rxm = pidhm % BLOCK;
  int rbm = pidhm / BLOCK;
  int rxn[TN] = (0 ... TN) % BLOCK;
  int rbn[TN] = (0 ... TN) / BLOCK;
  // extract information from look-up table
  int* header = LUT + rbm * 2;
  int size = *(header + 0);
  int offset = *(header + 1);
  bool check[TN] = rbn < size;
  int rbmn[TN] = check ? rbn : size - 1;
  // block id and column id
  long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
  long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
  long rowid   [TN] = *(LUT + offset + rbmn*4 + 2);
  long headid  [TN] = *(LUT + offset + rbmn*4 + 3);
  // pointers to X
  TYPE* px[TN] = X + pidz * stride_zx
                   + blockid * BLOCK * BLOCK
                   + rxm * BLOCK
                   + rxn;
#ifdef APPLY_RPE
  // pointers to relative position embedding
  TYPE* prpe[TN] = RPE + pidz * stride_zrpe
                       + headid * stride_hrpe
                       + columnid * BLOCK
                       + rowid * BLOCK * stride_srpe
                       + rxm * stride_srpe
                       + rxn;
#endif
#ifdef APPLY_KP_MASK
  // pointers to key padding mask
  TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
                         + columnid * BLOCK
                         + rxn;
#endif
#ifdef APPLY_ATTN_MASK
  // pointers to attention mask
  TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
                             + rowid * BLOCK * stride_zattnm
                             + rxm * stride_zattnm
                             + rxn;
#endif

  // load input
  TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
  // load relative position embedding
  TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
  // load key-padding mask
  TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
  // load attention mask
  TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
  // compute softmax in float
#ifdef APPLY_RPE
  float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
  float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
  float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
  Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
  Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
  float Fx[TN] = x;
#ifdef APPLY_SCALE
  Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
  Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
  Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
  Fx = Fx + Fattn_m; // apply attention mask
#endif
  float Fxmax = Fx[max];
  float Fy[TN] = exp(Fx - Fxmax);
  float Fysum = (check ? Fy : 0)[+];
  // write-back in half/float
  TYPE y[TN] = Fy;
  TYPE ysum = Fysum;
  *?(check)px = y / ysum;
}
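For orientation only (not part of the commit): up to the block-sparse storage format, the forward kernel above computes the following dense operation, assembled from its compile-time flags (APPLY_SCALE, APPLY_RPE, APPLY_KP_MASK/KP_MASK_MUL, APPLY_ATTN_MASK/ATTN_MASK_MUL). This is a hedged PyTorch sketch, not the library's API:

import torch

def dense_softmax_reference(x, layout, block, scale=1.0, rpe=None,
                            kp_mask=None, attn_mask=None,
                            kp_mask_mode='add', attn_mask_mode='add'):
    # x: dense scores of shape (Z, H, M, N); layout: (H, M//block, N//block)
    x = x.float()
    # blocks absent from the sparse layout behave like -inf before the softmax
    keep = layout[None, :, :, :].repeat_interleave(block, dim=2).repeat_interleave(block, dim=3)
    x = x.masked_fill(keep.to(x.device) == 0, float('-inf'))
    if scale != 1.0:                                     # APPLY_SCALE
        x = x * scale
    if rpe is not None:                                  # APPLY_RPE
        x = x + rpe.float()
    if kp_mask is not None:                              # APPLY_KP_MASK, (Z, N)
        m = kp_mask[:, None, None, :].float()
        if kp_mask_mode == 'mul':                        # KP_MASK_MUL: 0 -> -inf, nonzero -> 0
            m = m.masked_fill(m == 0, float('-inf')).masked_fill(m != 0, 0.)
        x = x + m
    if attn_mask is not None:                            # APPLY_ATTN_MASK, (N, N)
        m = attn_mask[None, None, :, :].float()
        if attn_mask_mode == 'mul':                      # ATTN_MASK_MUL
            m = m.masked_fill(m == 0, float('-inf')).masked_fill(m != 0, 0.)
        x = x + m
    # rows with no kept blocks are undefined here; the sparse path never materializes them
    return torch.softmax(x, dim=-1)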

__global__ void backward(TYPE * X __readonly __noalias __aligned(16),
                         float scale,
                         TYPE* DX __readonly __noalias __aligned(16),
                         int* LUT,
                         int sizemax,
                         long stride_zx __multipleof(BLOCK),
                         long stride_zdx __multipleof(BLOCK)) {
  int pidhm = get_program_id(0);
  int pidz = get_program_id(1);
  // create index ranges
  int rxm = pidhm % BLOCK;
  int rbm = pidhm / BLOCK;
  int rxn[TN] = (0 ... TN) % BLOCK;
  int rbn[TN] = (0 ... TN) / BLOCK;
  // extract information from look-up table
  int* header = LUT + rbm * 2;
  int size = *(header + 0);
  int offset = *(header + 1);
  // bounds checking on lut
  bool check[TN] = rbn < size;
  int rbmn[TN] = check ? rbn : size - 1;
  // initialize pointers to block-sparse input
  long blockid[TN] = *(LUT + offset + rbmn*4);
  TYPE* px[TN] = X + pidz * stride_zx
                   + blockid * BLOCK * BLOCK
                   + rxm * BLOCK
                   + rxn;
  TYPE* pdx[TN] = DX + pidz * stride_zdx
                     + blockid * BLOCK * BLOCK
                     + rxm * BLOCK
                     + rxn;
  // compute fused softmax backward
  TYPE x[TN] = check ? *px : 0;
  TYPE dx[TN] = check ? *pdx : 0;
  float Fdx[TN] = dx;
  float Fx[TN] = x;
  float Fxdx[TN] = Fdx*Fx;
  float Fxdxsum = Fxdx[+];
  float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
  TYPE y[TN] = Fy;
  // write-back
  *? (check)pdx = y;
}
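The backward kernel is the standard row-wise softmax Jacobian-vector product fused with the scale factor, applied block by block. As a dense sketch of the same formula (not part of the commit), with y the softmax output saved by the forward pass and dy the incoming gradient:

import torch

def softmax_backward_reference(y, dy, scale):
    # dL/dx = y * (dy - sum_j(y_j * dy_j)) * scale, reduced over the softmax axis
    return y * (dy - (y * dy).sum(dim=-1, keepdim=True)) * scale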
python/triton/ops/blocksparse/softmax.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import triton
import torch
import os

fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
                      kernel_names=['forward'])
fwd_kernels = dict()

bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
                      kernel_names=['backward'])
bwd_kernels = dict()


class _softmax(torch.autograd.Function):

    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # block indices
        idx = torch.arange(layout.sum())
        head = layout.nonzero(as_tuple=False)[:, 0]
        rows = layout.nonzero(as_tuple=False)[:, 1]
        columns = layout.nonzero(as_tuple=False)[:, 2]
        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
        # construct look-up table
        offsets = offsets*4 + 2*sizes.numel()
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, core)).type(torch.int32).to(device)
        return lut, int(sizes.max())
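To make the LUT format concrete, a small worked example (illustration only, not part of the commit). Per block-row the table stores a two-int header (number of blocks, offset into the body); the body stores four ints per nonzero block (block index, column, row, head), which is exactly what the forward kernel reads back:

import torch

# one head, a 2x2 grid of blocks with three nonzero blocks:
#   block-row 0: columns 0 and 1; block-row 1: column 1
layout = torch.tensor([[[1, 1],
                        [0, 1]]])
lut, maxlut = _softmax.make_lut(layout, block=16, device='cpu')
# headers (size, offset) per block-row: (2, 4) and (1, 12)
# body (blockid, column, row, head) per nonzero block
# lut == [2, 4, 1, 12,  0, 0, 0, 0,  1, 1, 0, 0,  2, 1, 1, 0]; maxlut == 2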
    @staticmethod
    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
        if max_k >= 32768:
            raise NotImplementedError('Reductions larger than 32768 elements '\
                                      'are not yet implemented')
        num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
        pad = num_warps * 32 * 2
        TN = (int(max_k) + pad-1)//pad * pad
        # just-in-time compile kernel
        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
        if key not in cache:
            defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
                       'INFINITY': {torch.float32: 'F32_INFINITY',
                                    torch.float16: 'F16_INFINITY'}[dtype]}
            if apply_scale:
                defines['APPLY_SCALE'] = True
            if apply_rpe:
                defines['APPLY_RPE'] = True
            if apply_kp_mask:
                defines['APPLY_KP_MASK'] = True
                if kp_mask_mode == 'mul':
                    defines['KP_MASK_MUL'] = True
            if apply_attn_mask:
                defines['APPLY_ATTN_MASK'] = True
                if attn_mask_mode == 'mul':
                    defines['ATTN_MASK_MUL'] = True
            kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps])
            cache[key] = kernel
        return cache[key]
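As a concrete illustration of the tile-size choice (hypothetical numbers, not from the commit): max_k covers the widest block-row, and TN is rounded up to a multiple of num_warps * 32 * 2 so a single program covers the whole row.

# e.g. a block-row spanning max_k = 576 columns
max_k = 576
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)   # -> 8
pad = num_warps * 32 * 2                                        # -> 512
TN = (int(max_k) + pad - 1) // pad * pad                        # -> 1024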
    @staticmethod
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode,
                spdims, block, lut, maxlut, bench, time):
        apply_scale = False if scale == 1.0 else True

        # handle None rpe
        if rpe is None:
            apply_rpe = False
            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_rpe = True
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)

        # handle None key_padding_mask
        if key_padding_mask is None:
            apply_kp_mask = False
            stride_zkpm = 0
            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_kp_mask = True
            stride_zkpm = key_padding_mask.stride(0)

        # handle None attention_mask
        if attn_mask is None:
            apply_attn_mask = False
            stride_zattnm = 0
            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)

        # run kernel
        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut*block, x.device, x.dtype, block,
                                      apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
                                      kp_mask_mode, attn_mask_mode)
        M = x.shape[0]
        grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
        time[0] = kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
                         maxlut,
                         x.stride(0),
                         stride_zrpe, stride_hrpe, stride_srpe,
                         stride_zkpm, stride_zattnm,
                         grid=grid)
        # save to context
        ctx.mark_dirty(x)
        ctx.save_for_backward(x, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.apply_scale = apply_scale
        ctx.apply_rpe = apply_rpe
        ctx.apply_kp_mask = apply_kp_mask
        ctx.apply_attn_mask = apply_attn_mask
        ctx.kp_mask_mode = kp_mask_mode
        ctx.attn_mask_mode = attn_mask_mode
        return x
    @staticmethod
    def backward(ctx, dx):
        # retrieve from context
        x, lut = ctx.saved_tensors
        # run kernel
        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut*ctx.block, x.device, x.dtype, ctx.block,
                                      ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
                                      ctx.kp_mask_mode, ctx.attn_mask_mode)
        M = x.shape[0]
        grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M]
        kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class softmax:

    apply_softmax = _softmax.apply

    def make_lut(self, device):
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]

    def __init__(self, layout, block, bench = False):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.bench = bench
        self.lut_cache = dict()

    def __call__(self, x, scale = 1., rpe = None, key_padding_mask = None, attn_mask = None,
                 key_padding_mask_mode='add', attn_mask_mode='add'):
        time_y = [None]
        if rpe is not None and rpe.dtype != x.dtype:
            raise ValueError('relative position embedding must be %s' % x.dtype)
        if attn_mask is not None and attn_mask.dtype != x.dtype:
            raise ValueError('Attention mask must be %s' % x.dtype)
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask,
                                  key_padding_mask_mode, attn_mask_mode,
                                  self.spdims, self.block,
                                  lut,
                                  maxlut, self.bench, time_y)
        return x
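Putting the pieces together, a minimal usage sketch of the new op, mirroring the test above. Shapes are illustrative; tt is assumed to be the triton module, and sparsify_tensor is the test helper sketched earlier.

import torch
import triton as tt

Z, H, M, N, BLOCK = 2, 4, 256, 256, 16
# one layout entry per (head, block-row, block-column); 1 keeps the block
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=torch.float16, device='cuda')

op = tt.ops.blocksparse.softmax(layout, BLOCK)
tx = sparsify_tensor(x, layout, BLOCK)   # pack the nonzero blocks
ty = op(tx, scale=0.4)                   # block-sparse softmax over each row

Note that the forward kernel writes its result in place and the autograd function marks its input dirty, so ty aliases tx rather than being a fresh tensor.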