Files
triton/python/triton/ops/blocksparse/softmax.c
2021-07-27 12:38:48 -07:00

155 lines
5.1 KiB
C

// Block-sparse softmax, forward pass. The result overwrites X in place.
//
// Each program instance handles one row: program id 0 (pidhm) is split into
// a row-within-block index (rxm) and a block-row index (rbm); program id 1
// (pidz) selects the slice along the z dimension through the stride_z*
// arguments. The TN lanes of the tile cover up to TN / BLOCK blocks of that
// row, as listed in LUT.
//
// Parameters:
//   X             values to normalise; read through px and written back
//                 through px at the end of the kernel.
//   scale         multiplier applied to X when APPLY_SCALE is defined.
//   LUT           look-up table. At LUT + rbm*2: {size, offset}. At
//                 LUT + offset: `size` entries of 4 ints each, read here as
//                 {blockid, columnid, rowid, headid}.
//   RPE           relative position embedding, added when APPLY_RPE is defined.
//   KP_M          key-padding mask, used when APPLY_KP_MASK is defined.
//   ATTN_M        attention mask, used when APPLY_ATTN_MASK is defined.
//   sizemax       not referenced in this kernel body.
//   stride_zx     z-stride of X.
//   stride_zrpe, stride_hrpe, stride_srpe
//                 strides of RPE multiplied by pidz, headid and the row
//                 index (rowid * BLOCK + rxm) respectively.
//   stride_zkpm   z-stride of KP_M.
//   stride_zattnm multiplied by the row index (rowid * BLOCK + rxm) when
//                 addressing ATTN_M; ATTN_M is not offset by pidz.
//
// TYPE, BLOCK, TN and INFINITY are not defined in this file section; they are
// expected to be supplied from outside (presumably compile-time macros).
//
// NOTE(review): X is annotated __readonly but is the target of the final
// store. Confirm what __readonly guarantees to the compiler.
__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
float scale,
int *LUT __readonly __noalias __aligned(16),
TYPE *RPE __readonly __noalias __aligned(16),
TYPE *KP_M __readonly __noalias __aligned(16),
TYPE *ATTN_M __readonly __noalias __aligned(16),
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zrpe __multipleof(BLOCK),
int stride_hrpe __multipleof(BLOCK),
int stride_srpe __multipleof(BLOCK),
int stride_zkpm __multipleof(BLOCK),
int stride_zattnm __multipleof(BLOCK)){
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
// rxm: row inside the block; rbm: index of the LUT header for this row
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
// rxn: column inside the block for each lane; rbn: which LUT entry the lane maps to
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
// header = {number of entries for this block-row, offset of the first entry}
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// check marks the lanes that map to an existing LUT entry
bool check[TN] = rbn < size;
// lanes past the end are clamped to the last entry so the LUT reads below
// use an existing index. NOTE(review): if size is 0 this index is -1 and the
// reads land at LUT + offset - 4 .. offset - 1; confirm size >= 1 always holds.
int rbmn[TN] = check ? rbn : size - 1;
// block id and column id
// each LUT entry is 4 ints: {blockid, columnid, rowid, headid}
long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
long headid [TN] = *(LUT + offset + rbmn*4 + 3);
// pointers to X
// blocks are addressed as BLOCK*BLOCK contiguous elements, row stride BLOCK
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
#ifdef APPLY_RPE
// pointers to relative position embedding
// column = columnid * BLOCK + rxn, row = rowid * BLOCK + rxm (scaled by stride_srpe)
TYPE* prpe[TN] = RPE + pidz * stride_zrpe
+ headid * stride_hrpe
+ columnid * BLOCK
+ rowid * BLOCK * stride_srpe
+ rxm * stride_srpe
+ rxn;
#endif
#ifdef APPLY_KP_MASK
// pointers to key padding mask
// indexed by z slice and column only (no row or head term)
TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
+ columnid * BLOCK
+ rxn;
#endif
#ifdef APPLY_ATTN_MASK
// pointers to attention mask
// indexed by row and column only; stride_zattnm is used as the row stride
TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
+ rowid * BLOCK * stride_zattnm
+ rxm * stride_zattnm
+ rxn;
#endif
// load input
// lanes outside the row get -INFINITY, so exp() of them is 0 further down
TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
// load relative position embedding
// lanes outside the row get 0 (no contribution to the sum below)
TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
// load key-padding mask
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
// load attention mask
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
// compute softmax in float
// convert every loaded tile from TYPE to float before doing arithmetic
#ifdef APPLY_RPE
float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
// turn the mask into additive form: 0 becomes -INFINITY, anything else becomes 0.
// Fkp_m is only declared under APPLY_KP_MASK, so this section relies on that
// macro being defined as well.
Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
// same conversion for the attention mask; relies on APPLY_ATTN_MASK being
// defined, since Fattn_m is only declared there.
Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
float Fx[TN] = x;
#ifdef APPLY_SCALE
Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
Fx = Fx + Fattn_m; // apply attention mask
#endif
// subtract the maximum over the lanes before exponentiating
float Fxmax = Fx[max];
float Fy[TN] = exp(Fx - Fxmax);
// sum of the exponentials, counting only lanes that map to a real entry
float Fysum = (check ? Fy : 0)[+];
// write-back in half/float
// NOTE(review): numerator and denominator are both converted to TYPE before
// the division, so the division itself is carried out in TYPE, not in float.
TYPE y[TN] = Fy;
TYPE ysum = Fysum;
// store predicated on check: only lanes that map to a real entry are written
*?(check)px = y / ysum;
}
// Block-sparse softmax, backward pass. The result overwrites DX in place.
//
// For every lane that maps to a real LUT entry the kernel computes
//     X * (DX - sum(X * DX)) * scale
// where the sum runs over the lanes of the row, and stores it to DX.
// Presumably X holds the softmax output saved by the forward pass and DX the
// incoming gradient; that matches the formula but should be confirmed against
// the caller.
//
// Parameters:
//   X          first operand of the formula above; only read.
//   scale      multiplier applied to the result.
//   DX         second operand; read through pdx and overwritten with the result.
//   LUT        look-up table. At LUT + rbm*2: {size, offset}. At LUT + offset:
//              4-int entries, of which only the first field (blockid) is read.
//   sizemax    not referenced in this kernel body.
//   stride_zx  z-stride of X.
//   stride_zdx z-stride of DX.
//
// TYPE, BLOCK and TN are not defined in this file section; they are expected
// to be supplied from outside (presumably compile-time macros).
//
// NOTE(review): DX is annotated __readonly but is the target of the final
// store. Confirm what __readonly guarantees to the compiler.
__global__ void backward(TYPE * X __readonly __noalias __aligned(16),
float scale,
TYPE* DX __readonly __noalias __aligned(16),
int* LUT,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zdx __multipleof(BLOCK)) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
// rxm: row inside the block; rbm: index of the LUT header for this row
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
// rxn: column inside the block for each lane; rbn: which LUT entry the lane maps to
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
// header = {number of entries for this block-row, offset of the first entry}
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// bounds checking on lut
// lanes past the end are clamped to the last entry so the LUT read below uses
// an existing index. NOTE(review): if size is 0 this index is -1 and the read
// lands at LUT + offset - 4; confirm size >= 1 always holds.
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// initialize pointers to block-sparse input
// X and DX are addressed with the same block layout (BLOCK*BLOCK per block,
// row stride BLOCK) and differ only in base pointer and z-stride
long blockid[TN] = *(LUT + offset + rbmn*4);
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
TYPE* pdx[TN] = DX + pidz * stride_zdx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
// compute fused softmax backward
// lanes outside the row load 0, so they add nothing to the sum below
TYPE x[TN] = check ? *px : 0;
TYPE dx[TN] = check ? *pdx : 0;
// arithmetic is done in float
float Fdx[TN] = dx;
float Fx[TN] = x;
float Fxdx[TN] = Fdx*Fx;
// sum of X * DX over the lanes of the row
float Fxdxsum = Fxdx[+];
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
// convert back to the storage type
TYPE y[TN] = Fy;
// write-back
// store predicated on check: only lanes that map to a real entry are written
*? (check)pdx = y;
}