[PYTHON][OPS][BLOCKSPARSE] Now rounding softmax tile sizes to next power of 2
This commit is contained in:
@@ -58,8 +58,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H =
|
|||||||
|
|
||||||
@pytest.mark.parametrize("BLOCK, WIDTH",
|
@pytest.mark.parametrize("BLOCK, WIDTH",
|
||||||
[
|
[
|
||||||
(block, width) for block in [16]\
|
(block, width) for block in [32]\
|
||||||
for width in [256, 576]
|
for width in [256, 576, 1024, 2048, 4096]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
|
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
|
||||||
@@ -78,15 +78,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
|
|||||||
# triton result
|
# triton result
|
||||||
op = tt.ops.blocksparse.softmax(layout, BLOCK)
|
op = tt.ops.blocksparse.softmax(layout, BLOCK)
|
||||||
tx = sparsify_tensor(x, layout, BLOCK)
|
tx = sparsify_tensor(x, layout, BLOCK)
|
||||||
ty = op(tx, scale=scale, key_padding_mask=kp_mask, key_padding_mask_mode='add', attn_mask=at_mask.to(DTYPE), attn_mask_mode='mul')
|
ty = op(tx, scale=scale)
|
||||||
# torch result
|
# torch result
|
||||||
rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
|
rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
|
||||||
if at_mask is not None:
|
# if at_mask is not None:
|
||||||
# broadcast at_mask to the same shape as rx
|
# # broadcast at_mask to the same shape as rx
|
||||||
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
|
# M = at_mask[None, None, :, :] + torch.zeros_like(rx)
|
||||||
rx[M == 0] = float('-inf')
|
# rx[M == 0] = float('-inf')
|
||||||
if kp_mask is not None:
|
# if kp_mask is not None:
|
||||||
rx += kp_mask[:, None, None, :]
|
# rx += kp_mask[:, None, None, :]
|
||||||
ry = torch.softmax(rx*scale, -1)
|
ry = torch.softmax(rx*scale, -1)
|
||||||
ry = sparsify_tensor(ry, layout, BLOCK)
|
ry = sparsify_tensor(ry, layout, BLOCK)
|
||||||
# compare
|
# compare
|
||||||
|
@@ -5,12 +5,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
|
|||||||
TYPE *KP_M __readonly __noalias __aligned(16),
|
TYPE *KP_M __readonly __noalias __aligned(16),
|
||||||
TYPE *ATTN_M __readonly __noalias __aligned(16),
|
TYPE *ATTN_M __readonly __noalias __aligned(16),
|
||||||
int sizemax,
|
int sizemax,
|
||||||
long stride_zx __multipleof(BLOCK),
|
long stride_zx __multipleof(4),
|
||||||
long stride_zrpe __multipleof(BLOCK),
|
long stride_zrpe __multipleof(BLOCK),
|
||||||
int stride_hrpe __multipleof(BLOCK),
|
int stride_hrpe __multipleof(BLOCK),
|
||||||
int stride_srpe __multipleof(BLOCK),
|
int stride_srpe __multipleof(BLOCK),
|
||||||
int stride_zkpm __multipleof(BLOCK),
|
int stride_zkpm __multipleof(BLOCK),
|
||||||
int stride_zattnm __multipleof(BLOCK)){
|
int stride_zattnm __multipleof(BLOCK))
|
||||||
|
{
|
||||||
int pidhm = get_program_id(0);
|
int pidhm = get_program_id(0);
|
||||||
int pidz = get_program_id(1);
|
int pidz = get_program_id(1);
|
||||||
// create index ranges
|
// create index ranges
|
||||||
@@ -19,42 +20,29 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
|
|||||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||||
// extract information from look-up table
|
// extract information from look-up table
|
||||||
int* header = LUT + rbm * 2;
|
int *header = LUT + rbm * 2;
|
||||||
int size = *(header + 0);
|
int size = *(header + 0);
|
||||||
int offset = *(header + 1);
|
int offset = *(header + 1);
|
||||||
bool check[TN] = rbn < size;
|
bool check[TN] = rbn < size;
|
||||||
int rbmn[TN] = check ? rbn : size - 1;
|
int rbmn[TN] = check ? rbn : size - 1;
|
||||||
// block id and column id
|
// block id and column id
|
||||||
long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
|
long blockid[TN] = *(LUT + offset + rbmn * 4 + 0);
|
||||||
long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
|
long columnid[TN] = *(LUT + offset + rbmn * 4 + 1);
|
||||||
long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
|
long rowid[TN] = *(LUT + offset + rbmn * 4 + 2);
|
||||||
long headid [TN] = *(LUT + offset + rbmn*4 + 3);
|
long headid[TN] = *(LUT + offset + rbmn * 4 + 3);
|
||||||
// pointers to X
|
// pointers to X
|
||||||
TYPE* px[TN] = X + pidz * stride_zx
|
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||||
+ blockid * BLOCK * BLOCK
|
|
||||||
+ rxm * BLOCK
|
|
||||||
+ rxn;
|
|
||||||
#ifdef APPLY_RPE
|
#ifdef APPLY_RPE
|
||||||
// pointers to relative position embedding
|
// pointers to relative position embedding
|
||||||
TYPE* prpe[TN] = RPE + pidz * stride_zrpe
|
TYPE *prpe[TN] = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn;
|
||||||
+ headid * stride_hrpe
|
|
||||||
+ columnid * BLOCK
|
|
||||||
+ rowid * BLOCK * stride_srpe
|
|
||||||
+ rxm * stride_srpe
|
|
||||||
+ rxn;
|
|
||||||
#endif
|
#endif
|
||||||
#ifdef APPLY_KP_MASK
|
#ifdef APPLY_KP_MASK
|
||||||
// pointers to key padding mask
|
// pointers to key padding mask
|
||||||
TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
|
TYPE *pkp_m[TN] = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn;
|
||||||
+ columnid * BLOCK
|
|
||||||
+ rxn;
|
|
||||||
#endif
|
#endif
|
||||||
#ifdef APPLY_ATTN_MASK
|
#ifdef APPLY_ATTN_MASK
|
||||||
// pointers to attention mask
|
// pointers to attention mask
|
||||||
TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
|
TYPE *pattn_m[TN] = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn;
|
||||||
+ rowid * BLOCK * stride_zattnm
|
|
||||||
+ rxm * stride_zattnm
|
|
||||||
+ rxn;
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// load input
|
// load input
|
||||||
@@ -82,10 +70,10 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
|
|||||||
float Fattn_m[TN] = attn_m;
|
float Fattn_m[TN] = attn_m;
|
||||||
#endif
|
#endif
|
||||||
#ifdef KP_MASK_MUL
|
#ifdef KP_MASK_MUL
|
||||||
Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
|
Fkp_m = (Fkp_m == 0) ? (float[TN]) - INFINITY : 0;
|
||||||
#endif
|
#endif
|
||||||
#ifdef ATTN_MASK_MUL
|
#ifdef ATTN_MASK_MUL
|
||||||
Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
|
Fattn_m = (Fattn_m == 0) ? (float[TN]) - INFINITY : 0;
|
||||||
#endif
|
#endif
|
||||||
float Fx[TN] = x;
|
float Fx[TN] = x;
|
||||||
#ifdef APPLY_SCALE
|
#ifdef APPLY_SCALE
|
||||||
@@ -106,16 +94,17 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
|
|||||||
// write-back in half/float
|
// write-back in half/float
|
||||||
TYPE y[TN] = Fy;
|
TYPE y[TN] = Fy;
|
||||||
TYPE ysum = Fysum;
|
TYPE ysum = Fysum;
|
||||||
*?(check)px = y / ysum;
|
*? (check)px = y / ysum;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void backward(TYPE * X __readonly __noalias __aligned(16),
|
__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
|
||||||
float scale,
|
float scale,
|
||||||
TYPE* DX __readonly __noalias __aligned(16),
|
TYPE *DX __readonly __noalias __aligned(16),
|
||||||
int* LUT,
|
int *LUT,
|
||||||
int sizemax,
|
int sizemax,
|
||||||
long stride_zx __multipleof(BLOCK),
|
long stride_zx __multipleof(BLOCK),
|
||||||
long stride_zdx __multipleof(BLOCK)) {
|
long stride_zdx __multipleof(BLOCK))
|
||||||
|
{
|
||||||
int pidhm = get_program_id(0);
|
int pidhm = get_program_id(0);
|
||||||
int pidz = get_program_id(1);
|
int pidz = get_program_id(1);
|
||||||
// create index ranges
|
// create index ranges
|
||||||
@@ -124,28 +113,22 @@ __global__ void backward(TYPE * X __readonly __noalias __aligned(16),
|
|||||||
int rxn[TN] = (0 ... TN) % BLOCK;
|
int rxn[TN] = (0 ... TN) % BLOCK;
|
||||||
int rbn[TN] = (0 ... TN) / BLOCK;
|
int rbn[TN] = (0 ... TN) / BLOCK;
|
||||||
// extract information from look-up table
|
// extract information from look-up table
|
||||||
int* header = LUT + rbm * 2;
|
int *header = LUT + rbm * 2;
|
||||||
int size = *(header + 0);
|
int size = *(header + 0);
|
||||||
int offset = *(header + 1);
|
int offset = *(header + 1);
|
||||||
// bounds checking on lut
|
// bounds checking on lut
|
||||||
bool check[TN] = rbn < size;
|
bool check[TN] = rbn < size;
|
||||||
int rbmn[TN] = check ? rbn : size - 1;
|
int rbmn[TN] = check ? rbn : size - 1;
|
||||||
// initialize pointers to block-sparse input
|
// initialize pointers to block-sparse input
|
||||||
long blockid[TN] = *(LUT + offset + rbmn*4);
|
long blockid[TN] = *(LUT + offset + rbmn * 4);
|
||||||
TYPE* px[TN] = X + pidz * stride_zx
|
TYPE *px[TN] = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||||
+ blockid * BLOCK * BLOCK
|
TYPE *pdx[TN] = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn;
|
||||||
+ rxm * BLOCK
|
|
||||||
+ rxn;
|
|
||||||
TYPE* pdx[TN] = DX + pidz * stride_zdx
|
|
||||||
+ blockid * BLOCK * BLOCK
|
|
||||||
+ rxm * BLOCK
|
|
||||||
+ rxn;
|
|
||||||
// compute fused softmax backward
|
// compute fused softmax backward
|
||||||
TYPE x[TN] = check ? *px : 0;
|
TYPE x[TN] = check ? *px : 0;
|
||||||
TYPE dx[TN] = check ? *pdx : 0;
|
TYPE dx[TN] = check ? *pdx : 0;
|
||||||
float Fdx[TN] = dx;
|
float Fdx[TN] = dx;
|
||||||
float Fx[TN] = x;
|
float Fx[TN] = x;
|
||||||
float Fxdx[TN] = Fdx*Fx;
|
float Fxdx[TN] = Fdx * Fx;
|
||||||
float Fxdxsum = Fxdx[+];
|
float Fxdxsum = Fxdx[+];
|
||||||
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
|
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
|
||||||
TYPE y[TN] = Fy;
|
TYPE y[TN] = Fy;
|
||||||
|
@@ -15,6 +15,17 @@ bwd_kernels = dict()
|
|||||||
class _softmax(torch.autograd.Function):
|
class _softmax(torch.autograd.Function):
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to *n*.

    Used to round the softmax tile size (``TN``) up to a power of two.

    Args:
        n: Positive integer, or any value accepted by ``int()``.
            NOTE(review): the caller passes ``max_k``; the previous tile-size
            code wrapped it in ``int(...)``, so it may not be a plain Python
            int -- confirm against the caller.

    Returns:
        A Python ``int`` that is a power of two and ``>= n``. Values of
        ``n`` below 1 yield 1, the smallest power of two.
    """
    # Coerce to a fresh Python int so the caller's object is never mutated
    # in place and ``bit_length`` is available.
    n = int(n)
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed to represent n - 1;
    # shifting 1 by that amount gives the next power of two, and leaves exact
    # powers of two unchanged. Unlike a fixed chain of ``n |= n >> k`` steps,
    # this is correct for arbitrarily large integers.
    return 1 << (n - 1).bit_length()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_lut(layout, block, device):
|
def make_lut(layout, block, device):
|
||||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||||
@@ -43,8 +54,7 @@ class _softmax(torch.autograd.Function):
|
|||||||
raise NotImplementedError('Reductions larger than 32768 elements '\
|
raise NotImplementedError('Reductions larger than 32768 elements '\
|
||||||
'are not yet implemented')
|
'are not yet implemented')
|
||||||
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
|
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
|
||||||
pad = num_warps * 32 * 2
|
TN = _softmax.next_power_of_2(max_k)
|
||||||
TN = (int(max_k) + pad-1)//pad * pad
|
|
||||||
# just-in-time compile kernel
|
# just-in-time compile kernel
|
||||||
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
|
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
|
||||||
if key not in cache:
|
if key not in cache:
|
||||||
@@ -108,11 +118,8 @@ class _softmax(torch.autograd.Function):
|
|||||||
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
|
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
|
||||||
|
|
||||||
# run kernel
|
# run kernel
|
||||||
time[0] = kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),\
|
kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
|
||||||
maxlut,\
|
maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||||
x.stride(0),\
|
|
||||||
stride_zrpe, stride_hrpe, stride_srpe,\
|
|
||||||
stride_zkpm, stride_zattnm,\
|
|
||||||
grid=grid)
|
grid=grid)
|
||||||
# save to context
|
# save to context
|
||||||
ctx.mark_dirty(x)
|
ctx.mark_dirty(x)
|
||||||
|
Reference in New Issue
Block a user