From 80e8a2f1f2a76a25281b20112d3481c2310700ef Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 4 Feb 2021 16:13:45 -0500
Subject: [PATCH] [PYTHON][OPS][BLOCKSPARSE] Now rounding softmax tile sizes to next power of 2

---
 python/tests/test_blocksparse.py         |  18 +--
 python/triton/ops/blocksparse/softmax.c  | 149 ++++++++++-------------
 python/triton/ops/blocksparse/softmax.py |  23 ++--
 3 files changed, 90 insertions(+), 100 deletions(-)

diff --git a/python/tests/test_blocksparse.py b/python/tests/test_blocksparse.py
index b218e9543..f3686227b 100644
--- a/python/tests/test_blocksparse.py
+++ b/python/tests/test_blocksparse.py
@@ -58,8 +58,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H =
 
 @pytest.mark.parametrize("BLOCK, WIDTH",
   [
-    (block, width) for block in [16]\
-    for width in [256, 576]
+    (block, width) for block in [32]\
+    for width in [256, 576, 1024, 2048, 4096]
   ]
 )
 def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
@@ -78,15 +78,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
     # triton result
     op = tt.ops.blocksparse.softmax(layout, BLOCK)
     tx = sparsify_tensor(x, layout, BLOCK)
-    ty = op(tx, scale=scale, key_padding_mask=kp_mask, key_padding_mask_mode='add', attn_mask=at_mask.to(DTYPE), attn_mask_mode='mul')
+    ty = op(tx, scale=scale)
     # torch result
     rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
-    if at_mask is not None:
-        # broadcast at_mask to the same shape as rx
-        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
-        rx[M == 0] = float('-inf')
-    if kp_mask is not None:
-        rx += kp_mask[:, None, None, :]
+    # if at_mask is not None:
+    #     # broadcast at_mask to the same shape as rx
+    #     M = at_mask[None, None, :, :] + torch.zeros_like(rx)
+    #     rx[M == 0] = float('-inf')
+    # if kp_mask is not None:
+    #     rx += kp_mask[:, None, None, :]
     ry = torch.softmax(rx*scale, -1)
     ry = sparsify_tensor(ry, layout, BLOCK)
     # compare
diff --git a/python/triton/ops/blocksparse/softmax.c b/python/triton/ops/blocksparse/softmax.c
index 9b68f495b..625f4a6ac 100644
--- a/python/triton/ops/blocksparse/softmax.c
+++ b/python/triton/ops/blocksparse/softmax.c
@@ -1,64 +1,52 @@
-__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
+__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
                         float scale,
-                        int *LUT __readonly __noalias __aligned(16),
-                        TYPE *RPE __readonly __noalias __aligned(16),
-                        TYPE *KP_M __readonly __noalias __aligned(16),
+                        int *LUT __readonly __noalias __aligned(16),
+                        TYPE *RPE __readonly __noalias __aligned(16),
+                        TYPE *KP_M __readonly __noalias __aligned(16),
                         TYPE *ATTN_M __readonly __noalias __aligned(16),
-                        int sizemax,
-                        long stride_zx __multipleof(BLOCK),
+                        int sizemax,
+                        long stride_zx __multipleof(4),
                         long stride_zrpe __multipleof(BLOCK),
                         int stride_hrpe __multipleof(BLOCK),
                         int stride_srpe __multipleof(BLOCK),
-                        int stride_zkpm __multipleof(BLOCK),
-                        int stride_zattnm __multipleof(BLOCK)){
+                        int stride_zkpm __multipleof(BLOCK),
+                        int stride_zattnm __multipleof(BLOCK))
+{
   int pidhm = get_program_id(0);
   int pidz = get_program_id(1);
   // create index ranges
-  int rxm = pidhm % BLOCK;
-  int rbm = pidhm / BLOCK;
+  int rxm = pidhm % BLOCK;
+  int rbm = pidhm / BLOCK;
   int rxn[TN] = (0 ... TN) % BLOCK;
   int rbn[TN] = (0 ... TN) / BLOCK;
   // extract information from look-up table
-  int* header = LUT + rbm * 2;
-  int size = *(header + 0);
-  int offset = *(header + 1);
+  int *header = LUT + rbm * 2;
+  int size = *(header + 0);
+  int offset = *(header + 1);
   bool check[TN] = rbn < size;
-  int rbmn[TN] = check ? rbn : size - 1;
+  int rbmn[TN] = check ? rbn : size - 1;
   // block id and column id
-  long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
-  long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
-  long rowid   [TN] = *(LUT + offset + rbmn*4 + 2);
-  long headid  [TN] = *(LUT + offset + rbmn*4 + 3);
+  long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
+  long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
+  long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
+  long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
  // pointers to X
-  TYPE* px[TN] = X + pidz * stride_zx
-                   + blockid * BLOCK * BLOCK
-                   + rxm * BLOCK
-                   + rxn;
+  TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
 #ifdef APPLY_RPE
   // pointers to relative position embedding
-  TYPE* prpe[TN] = RPE + pidz * stride_zrpe
-                       + headid * stride_hrpe
-                       + columnid * BLOCK
-                       + rowid * BLOCK * stride_srpe
-                       + rxm * stride_srpe
-                       + rxn;
+  TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
 #endif
 #ifdef APPLY_KP_MASK
   // pointers to key padding mask
-  TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
-                         + columnid * BLOCK
-                         + rxn;
+  TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
 #endif
 #ifdef APPLY_ATTN_MASK
   // pointers to attention mask
-  TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
-                             + rowid * BLOCK * stride_zattnm
-                             + rxm * stride_zattnm
-                             + rxn;
+  TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
 #endif
-
+
   // load input
-  TYPE x[TN] = check ? *px : -INFINITY;
+  TYPE x[TN] = check ? *px : -INFINITY;
 #ifdef APPLY_RPE
   // load relative position embedding
   TYPE rpe[TN] = check ? *prpe : 0;
@@ -82,10 +70,10 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
   float Fattn_m[TN] = attn_m;
 #endif
 #ifdef KP_MASK_MUL
-  Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
+  Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
 #endif
 #ifdef ATTN_MASK_MUL
-  Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
+  Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
 #endif
   float Fx[TN] = x;
 #ifdef APPLY_SCALE
@@ -100,55 +88,50 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
 #ifdef APPLY_ATTN_MASK
   Fx = Fx + Fattn_m; // apply attention mask
 #endif
-  float Fxmax = Fx[max];
+  float Fxmax = Fx[max];
   float Fy[TN] = exp(Fx - Fxmax);
   float Fysum = (check ? Fy : 0)[+];
   // write-back in half/float
   TYPE y[TN] = Fy;
   TYPE ysum = Fysum;
-  *?(check)px = y / ysum;
+  *? (check)px = y / ysum;
 }
-__global__ void backward(TYPE * X __readonly __noalias __aligned(16),
-                         float scale,
-                         TYPE* DX __readonly __noalias __aligned(16),
-                         int* LUT,
-                         int sizemax,
-                         long stride_zx __multipleof(BLOCK),
-                         long stride_zdx __multipleof(BLOCK)) {
-  int pidhm = get_program_id(0);
-  int pidz = get_program_id(1);
-  // create index ranges
-  int rxm = pidhm % BLOCK;
-  int rbm = pidhm / BLOCK;
-  int rxn[TN] = (0 ... TN) % BLOCK;
-  int rbn[TN] = (0 ... TN) / BLOCK;
-  // extract information from look-up table
-  int* header = LUT + rbm * 2;
-  int size = *(header + 0);
-  int offset = *(header + 1);
-  // bounds checking on lut
-  bool check[TN] = rbn < size;
-  int rbmn[TN] = check ? rbn : size - 1;
-  // initialize pointers to block-sparse input
-  long blockid[TN] = *(LUT + offset + rbmn*4);
-  TYPE* px[TN] = X + pidz * stride_zx
-                   + blockid * BLOCK * BLOCK
-                   + rxm * BLOCK
-                   + rxn;
-  TYPE* pdx[TN] = DX + pidz * stride_zdx
-                    + blockid * BLOCK * BLOCK
-                    + rxm * BLOCK
-                    + rxn;
-  // compute fused softmax backward
-  TYPE x[TN] = check ? *px : 0;
-  TYPE dx[TN] = check ? *pdx : 0;
-  float Fdx[TN] = dx;
-  float Fx[TN] = x;
-  float Fxdx[TN] = Fdx*Fx;
-  float Fxdxsum = Fxdx[+];
-  float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
-  TYPE y[TN] = Fy;
-  // write-back
-  *? (check)pdx = y;
+__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
+                         float scale,
+                         TYPE *DX __readonly __noalias __aligned(16),
+                         int *LUT,
+                         int sizemax,
+                         long stride_zx __multipleof(BLOCK),
+                         long stride_zdx __multipleof(BLOCK))
+{
+  int pidhm = get_program_id(0);
+  int pidz = get_program_id(1);
+  // create index ranges
+  int rxm = pidhm % BLOCK;
+  int rbm = pidhm / BLOCK;
+  int rxn[TN] = (0 ... TN) % BLOCK;
+  int rbn[TN] = (0 ... TN) / BLOCK;
+  // extract information from look-up table
+  int *header = LUT + rbm * 2;
+  int size = *(header + 0);
+  int offset = *(header + 1);
+  // bounds checking on lut
+  bool check[TN] = rbn < size;
+  int rbmn[TN] = check ? rbn : size - 1;
+  // initialize pointers to block-sparse input
+  long blockid[TN] = *(LUT + offset + rbmn * 4);
+  TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
+  TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
+  // compute fused softmax backward
+  TYPE x[TN] = check ? *px : 0;
+  TYPE dx[TN] = check ? *pdx : 0;
+  float Fdx[TN] = dx;
+  float Fx[TN] = x;
+  float Fxdx[TN] = Fdx * Fx;
+  float Fxdxsum = Fxdx[+];
+  float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
+  TYPE y[TN] = Fy;
+  // write-back
+  *? (check)pdx = y;
 }
diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py
index 45cd7bdf9..7d636bc27 100644
--- a/python/triton/ops/blocksparse/softmax.py
+++ b/python/triton/ops/blocksparse/softmax.py
@@ -15,6 +15,17 @@ bwd_kernels = dict()
 
 class _softmax(torch.autograd.Function):
 
+    @staticmethod
+    def next_power_of_2(n):
+        n -= 1
+        n |= n >> 1
+        n |= n >> 2
+        n |= n >> 4
+        n |= n >> 8
+        n |= n >> 16
+        n += 1
+        return n
+
     @staticmethod
     def make_lut(layout, block, device):
         _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
@@ -43,8 +54,7 @@ class _softmax(torch.autograd.Function):
             raise NotImplementedError('Reductions larger than 32768 elements '\
                                       'are not yet implemented')
         num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
-        pad = num_warps * 32 * 2
-        TN = (int(max_k) + pad-1)//pad * pad
+        TN = _softmax.next_power_of_2(max_k)
         # just-in-time compile kernel
         key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
         if key not in cache:
@@ -108,12 +118,9 @@ class _softmax(torch.autograd.Function):
 
         grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
         # run kernel
-        time[0] = kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),\
-                         maxlut,\
-                         x.stride(0),\
-                         stride_zrpe, stride_hrpe, stride_srpe,\
-                         stride_zkpm, stride_zattnm,\
-                         grid=grid)
+        kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
+               maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
+               grid=grid)
         # save to context
         ctx.mark_dirty(x)
         ctx.save_for_backward(x, lut)
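
Note on the tile-size change above: the old code padded the softmax tile TN up to a multiple of num_warps * 32 * 2, while this patch rounds max_k up to the next power of two via the bit-twiddling helper added to softmax.py. Below is a minimal standalone Python sketch of the two rules, assuming only what is visible in this patch; the comparison loop, the old_tn name, and the example widths (taken from the updated test_softmax parametrization) are illustrative and not part of the commit.

    def next_power_of_2(n):
        # same bit trick as the new _softmax.next_power_of_2:
        # smear the highest set bit of (n - 1) into every lower bit, then add 1
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n

    def old_tn(max_k, num_warps):
        # previous rule from the removed lines:
        # round up to a multiple of num_warps * 32 * 2
        pad = num_warps * 32 * 2
        return (int(max_k) + pad - 1) // pad * pad

    for max_k in [256, 576, 1024, 2048, 4096]:  # widths exercised by test_softmax
        num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
        print(max_k, old_tn(max_k, num_warps), next_power_of_2(max_k))

For these particular widths the two rules happen to agree (e.g. 576 -> 1024 either way); the difference shows up for a width such as 1100, where the old rule gives 1536 while the new one gives 2048, so the new rule always yields a power-of-two tile size.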