[PYTHON][OPS][BLOCKSPARSE] Now rounding softmax tile sizes to next power of 2
Author: Philippe Tillet
Date:   2021-02-04 16:13:45 -05:00
parent  b8a52c70c9
commit  80e8a2f1f2
3 changed files with 90 additions and 100 deletions
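
In short: the wrapper used to pad the softmax reduction tile TN up to a multiple of num_warps * 32 * 2; it now rounds max_k up to the next power of 2, so TN is always a power of 2. A minimal sketch of the two policies (standalone Python; the helper names are mine):

    def old_tn(max_k, num_warps):
        # old policy: pad to a multiple of num_warps * 32 * 2
        pad = num_warps * 32 * 2
        return (int(max_k) + pad - 1) // pad * pad

    def new_tn(max_k):
        # new policy: round up to the next power of 2
        return 1 << (int(max_k) - 1).bit_length()

    assert old_tn(100, 4) == 256 and new_tn(100) == 128       # can shrink the tile
    assert old_tn(3000, 16) == 3072 and new_tn(3000) == 4096  # always a power of 2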

File 1 of 3: block-sparse softmax unit test

@@ -58,8 +58,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H =
 @pytest.mark.parametrize("BLOCK, WIDTH",
     [
-        (block, width) for block in [16]\
-                       for width in [256, 576]
+        (block, width) for block in [32]\
+                       for width in [256, 576, 1024, 2048, 4096]
     ]
 )
 def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
@@ -78,15 +78,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
     # triton result
     op = tt.ops.blocksparse.softmax(layout, BLOCK)
     tx = sparsify_tensor(x, layout, BLOCK)
-    ty = op(tx, scale=scale, key_padding_mask=kp_mask, key_padding_mask_mode='add', attn_mask=at_mask.to(DTYPE), attn_mask_mode='mul')
+    ty = op(tx, scale=scale)
     # torch result
     rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
-    if at_mask is not None:
-        # broadcast at_mask to the same shape as rx
-        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
-        rx[M == 0] = float('-inf')
-    if kp_mask is not None:
-        rx += kp_mask[:, None, None, :]
+    # if at_mask is not None:
+    #     # broadcast at_mask to the same shape as rx
+    #     M = at_mask[None, None, :, :] + torch.zeros_like(rx)
+    #     rx[M == 0] = float('-inf')
+    # if kp_mask is not None:
+    #     rx += kp_mask[:, None, None, :]
     ry = torch.softmax(rx*scale, -1)
     ry = sparsify_tensor(ry, layout, BLOCK)
     # compare
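
With the mask arguments dropped, the test now exercises only the scaled-softmax path (the dead branches are kept as comments rather than deleted). For reference, a sketch of the dense check the remaining lines perform, with mask_tensor semantics inferred from its call above (layout assumed to be an (H, R, C) 0/1 tensor of block flags):

    import torch

    def dense_softmax_reference(x, layout, block, scale):
        # blocks absent from the layout get -inf logits -> zero probability
        rx = x.clone()
        H, R, C = layout.shape
        for h in range(H):
            for i in range(R):
                for j in range(C):
                    if layout[h, i, j] == 0:
                        rx[:, h, i*block:(i+1)*block, j*block:(j+1)*block] = float('-inf')
        return torch.softmax(rx * scale, -1)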

File 2 of 3: Triton-C softmax kernel

@@ -1,64 +1,52 @@
 __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
                         float scale,
                         int *LUT __readonly __noalias __aligned(16),
                         TYPE *RPE __readonly __noalias __aligned(16),
                         TYPE *KP_M __readonly __noalias __aligned(16),
                         TYPE *ATTN_M __readonly __noalias __aligned(16),
                         int sizemax,
-                        long stride_zx __multipleof(BLOCK),
+                        long stride_zx __multipleof(4),
                         long stride_zrpe __multipleof(BLOCK),
                         int stride_hrpe __multipleof(BLOCK),
                         int stride_srpe __multipleof(BLOCK),
                         int stride_zkpm __multipleof(BLOCK),
-                        int stride_zattnm __multipleof(BLOCK)){
+                        int stride_zattnm __multipleof(BLOCK))
+{
     int pidhm = get_program_id(0);
     int pidz = get_program_id(1);
     // create index ranges
     int rxm = pidhm % BLOCK;
     int rbm = pidhm / BLOCK;
     int rxn[TN] = (0 ... TN) % BLOCK;
     int rbn[TN] = (0 ... TN) / BLOCK;
     // extract information from look-up table
-    int* header = LUT + rbm * 2;
+    int *header = LUT + rbm * 2;
     int size = *(header + 0);
     int offset = *(header + 1);
     bool check[TN] = rbn < size;
     int rbmn[TN] = check ? rbn : size - 1;
     // block id and column id
-    long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
-    long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
-    long rowid   [TN] = *(LUT + offset + rbmn*4 + 2);
-    long headid  [TN] = *(LUT + offset + rbmn*4 + 3);
+    long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
+    long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
+    long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
+    long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
     // pointers to X
-    TYPE* px[TN] = X + pidz * stride_zx
-                     + blockid * BLOCK * BLOCK
-                     + rxm * BLOCK
-                     + rxn;
+    TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
 #ifdef APPLY_RPE
     // pointers to relative position embedding
-    TYPE* prpe[TN] = RPE + pidz * stride_zrpe
-                         + headid * stride_hrpe
-                         + columnid * BLOCK
-                         + rowid * BLOCK * stride_srpe
-                         + rxm * stride_srpe
-                         + rxn;
+    TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
 #endif
 #ifdef APPLY_KP_MASK
     // pointers to key padding mask
-    TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
-                           + columnid * BLOCK
-                           + rxn;
+    TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
 #endif
 #ifdef APPLY_ATTN_MASK
     // pointers to attention mask
-    TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
-                               + rowid * BLOCK * stride_zattnm
-                               + rxm * stride_zattnm
-                               + rxn;
+    TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
 #endif
     // load input
     TYPE x[TN] = check ? *px : -INFINITY;
 #ifdef APPLY_RPE
     // load relative position embedding
     TYPE rpe[TN] = check ? *prpe : 0;
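
A note on the lookup table the kernel walks: each row-block has a two-int header followed by four-long records, one per nonzero block. The construction itself lives in _softmax.make_lut and is not part of this diff, so the following access sketch is my reading of the kernel, not the library's API:

    # LUT[2*rbm + 0] -> size:   number of nonzero blocks in row-block rbm
    # LUT[2*rbm + 1] -> offset: index of this row-block's first record
    # LUT[offset + 4*k + 0..3] -> (blockid, columnid, rowid, headid) of block k
    def lut_record(lut, rbm, k):
        size, offset = int(lut[2 * rbm]), int(lut[2 * rbm + 1])
        k = min(k, size - 1)  # kernel clamps out-of-range lanes the same way
        return tuple(int(v) for v in lut[offset + 4 * k : offset + 4 * k + 4])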
@@ -82,10 +70,10 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
     float Fattn_m[TN] = attn_m;
 #endif
 #ifdef KP_MASK_MUL
-    Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
+    Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
 #endif
 #ifdef ATTN_MASK_MUL
-    Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
+    Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
 #endif
     float Fx[TN] = x;
 #ifdef APPLY_SCALE
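
Context for the *_MASK_MUL branches above: in 'mul' mode a zero mask entry turns the logit into -INFINITY (so it gets zero probability), while in 'add' mode the mask values are added to the logits as-is. A plain-PyTorch paraphrase of that behavior (not the kernel itself):

    import torch

    def apply_mask(logits, mask, mode):
        # 'mul': zeros knock the logit out to -inf; nonzeros leave it unchanged
        if mode == 'mul':
            return logits + torch.where(mask == 0,
                                        torch.full_like(mask, float('-inf')),
                                        torch.zeros_like(mask))
        # 'add': additive mask, applied directly
        if mode == 'add':
            return logits + mask
        raise ValueError(mode)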
@@ -100,55 +88,50 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
 #ifdef APPLY_ATTN_MASK
     Fx = Fx + Fattn_m; // apply attention mask
 #endif
     float Fxmax = Fx[max];
     float Fy[TN] = exp(Fx - Fxmax);
     float Fysum = (check ? Fy : 0)[+];
     // write-back in half/float
     TYPE y[TN] = Fy;
     TYPE ysum = Fysum;
-    *?(check)px = y / ysum;
+    *? (check)px = y / ysum;
 }
 
-__global__ void backward(TYPE * X __readonly __noalias __aligned(16),
+__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
                          float scale,
-                         TYPE* DX __readonly __noalias __aligned(16),
-                         int* LUT,
+                         TYPE *DX __readonly __noalias __aligned(16),
+                         int *LUT,
                          int sizemax,
                          long stride_zx __multipleof(BLOCK),
-                         long stride_zdx __multipleof(BLOCK)) {
+                         long stride_zdx __multipleof(BLOCK))
+{
     int pidhm = get_program_id(0);
     int pidz = get_program_id(1);
     // create index ranges
     int rxm = pidhm % BLOCK;
     int rbm = pidhm / BLOCK;
     int rxn[TN] = (0 ... TN) % BLOCK;
     int rbn[TN] = (0 ... TN) / BLOCK;
     // extract information from look-up table
-    int* header = LUT + rbm * 2;
+    int *header = LUT + rbm * 2;
     int size = *(header + 0);
     int offset = *(header + 1);
     // bounds checking on lut
     bool check[TN] = rbn < size;
     int rbmn[TN] = check ? rbn : size - 1;
     // initialize pointers to block-sparse input
-    long blockid[TN] = *(LUT + offset + rbmn*4);
-    TYPE* px[TN] = X + pidz * stride_zx
-                     + blockid * BLOCK * BLOCK
-                     + rxm * BLOCK
-                     + rxn;
-    TYPE* pdx[TN] = DX + pidz * stride_zdx
-                      + blockid * BLOCK * BLOCK
-                      + rxm * BLOCK
-                      + rxn;
+    long blockid[TN] = *(LUT + offset + rbmn * 4);
+    TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
+    TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
     // compute fused softmax backward
     TYPE x[TN] = check ? *px : 0;
     TYPE dx[TN] = check ? *pdx : 0;
     float Fdx[TN] = dx;
     float Fx[TN] = x;
-    float Fxdx[TN] = Fdx*Fx;
+    float Fxdx[TN] = Fdx * Fx;
     float Fxdxsum = Fxdx[+];
     float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
     TYPE y[TN] = Fy;
     // write-back
     *? (check)pdx = y;
 }
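
The backward kernel computes the standard fused softmax gradient: with y = softmax(scale * x) (stored back into X by the in-place forward), dx = y * (dy - sum(y * dy)) * scale. A quick PyTorch sanity check of that identity (my sketch, dense and 1-D for brevity):

    import torch

    torch.manual_seed(0)
    scale = 0.5
    x = torch.randn(8, requires_grad=True)
    y = torch.softmax(x * scale, -1)
    dy = torch.randn(8)
    y.backward(dy)
    # fused form used by the kernel: Fy = Fx * (Fdx - Fxdxsum) * scale
    dx_fused = y.detach() * (dy - (dy * y.detach()).sum()) * scale
    assert torch.allclose(x.grad, dx_fused, atol=1e-6)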

File 3 of 3: Python softmax wrapper (_softmax autograd function)

@@ -15,6 +15,17 @@ bwd_kernels = dict()
 class _softmax(torch.autograd.Function):
+    @staticmethod
+    def next_power_of_2(n):
+        n -= 1
+        n |= n >> 1
+        n |= n >> 2
+        n |= n >> 4
+        n |= n >> 8
+        n |= n >> 16
+        n += 1
+        return n
+
     @staticmethod
     def make_lut(layout, block, device):
         _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
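
The shift cascade above smears the most significant bit of n - 1 into every lower bit, so the final increment lands exactly on the next power of 2; decrementing first keeps exact powers of 2 fixed (valid for 32-bit n). For example:

    assert _softmax.next_power_of_2(1) == 1
    assert _softmax.next_power_of_2(576) == 1024    # the non-power-of-2 test width
    assert _softmax.next_power_of_2(1024) == 1024   # exact powers map to themselves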
@@ -43,8 +54,7 @@ class _softmax(torch.autograd.Function):
             raise NotImplementedError('Reductions larger than 32768 elements '\
                                       'are not yet implemented')
         num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
-        pad = num_warps * 32 * 2
-        TN = (int(max_k) + pad-1)//pad * pad
+        TN = _softmax.next_power_of_2(max_k)
         # just-in-time compile kernel
         key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
         if key not in cache:
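
Assuming max_k ends up equal to the test's WIDTH (an assumption on my part; max_k is actually derived from the LUT), the new rule gives these tile sizes for the widths the test now covers:

    for width in [256, 576, 1024, 2048, 4096]:
        print(width, '->', _softmax.next_power_of_2(width))
    # 256 -> 256, 576 -> 1024, 1024 -> 1024, 2048 -> 2048, 4096 -> 4096

Note that TN is no longer tied to num_warps, which keeps its heuristic role but no longer influences padding.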
@@ -108,12 +118,9 @@ class _softmax(torch.autograd.Function):
         grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
         # run kernel
-        time[0] = kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),\
-                         maxlut,\
-                         x.stride(0),\
-                         stride_zrpe, stride_hrpe, stride_srpe,\
-                         stride_zkpm, stride_zattnm,\
-                         grid=grid)
+        kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
+               maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
+               grid=grid)
         # save to context
         ctx.mark_dirty(x)
         ctx.save_for_backward(x, lut)
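
For orientation, triton.cdiv in the grid lambda is ceiling division: axis 0 launches one program per TM-sized slice of the block rows, axis 1 one per batch entry. A toy check with made-up shapes:

    def cdiv(a, b):
        return (a + b - 1) // b  # same arithmetic as triton.cdiv

    spdims, block, TM, M = (2, 4, 4), 32, 32, 16   # hypothetical sizes
    grid = [cdiv(spdims[0] * spdims[1] * block, TM), M]
    assert grid == [8, 16]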