[PYTHON][OPS] Convolution: Some cleaning of Triton-C kernel
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
||||
import triton
|
||||
|
||||
N, C, K = 32, 8, 32
|
||||
H, W = 4, 4
|
||||
H, W = 16, 16
|
||||
R, S = 3, 3
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn(N, C, H, W).cuda()
|
||||
@@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda()
|
||||
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||
tc = triton.ops.conv(a, b)
|
||||
print((rc - tc).abs().max())
|
||||
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
|
||||
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
|
||||
#print(tc[31, 31,:,:])
|
@@ -21,56 +21,57 @@ void convnd(A_TYPE *A,
|
||||
int off_uh, int off_uw,
|
||||
int off_uah, int off_uaw,
|
||||
int off_uch, int off_ucw,
|
||||
int* a_delta, int* inc_a){
|
||||
int* ADELTA, int* ADIFF){
|
||||
|
||||
// range of indices along the reduction axis
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||
int ryb[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
|
||||
// initialize accumulator
|
||||
float c[TM, TN] = 0;
|
||||
|
||||
// pointers for A
|
||||
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||
int rabh[TM] = rxa / CW;
|
||||
int raw[TM] = rxa % CW;
|
||||
int rab[TM] = rabh / CH;
|
||||
int rah[TM] = rabh % CH;
|
||||
rah = rah * UPAW - off_uah;
|
||||
raw = raw * UPAH - off_uaw;
|
||||
int racr[TK] = rka / BW;
|
||||
int ras[TK] = rka % BW;
|
||||
int racr[TK] = rk / BW;
|
||||
int ras[TK] = rk % BW;
|
||||
int rac[TK] = racr / BH;
|
||||
int rar[TK] = racr % BH;
|
||||
rar = UPAR * rar;
|
||||
ras = UPAS * ras;
|
||||
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :];
|
||||
int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
|
||||
|
||||
// pointers for B
|
||||
int rbn[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s;
|
||||
int rbk[TK] = rk;
|
||||
int rbn[TN] = ryb;
|
||||
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
|
||||
|
||||
// pointers for A look-up table
|
||||
int offda[TK] = rka % LUT_SIZE;
|
||||
int* pincd[TK] = inc_a + offda;
|
||||
int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
||||
int da[TK] = *pda;
|
||||
int incd[TK] = *pincd;
|
||||
int rklut[TK] = rk % LUT_SIZE;
|
||||
int* padiff[TK] = ADIFF + rklut;
|
||||
int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
||||
int adiff[TK] = *padiff;
|
||||
int adelta[TK] = *padelta;
|
||||
|
||||
// reduction loop
|
||||
A_TYPE a[TM, TK] = *pa;
|
||||
B_TYPE b[TK, TN] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
c += a @ b;
|
||||
pa += da[newaxis, :];
|
||||
pa += adelta[newaxis, :];
|
||||
pb += TK * ldb_s;
|
||||
// increment A look-up table
|
||||
pda = pda + incd;
|
||||
da = *pda;
|
||||
pincd = pincd + incd;
|
||||
incd = *pincd;
|
||||
padelta = padelta + adiff;
|
||||
adelta = *padelta;
|
||||
padiff = padiff + adiff;
|
||||
adiff = *padiff;
|
||||
// pre-fetches
|
||||
bool checka[TM, TK] = k > TK;
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
@@ -78,7 +79,6 @@ void convnd(A_TYPE *A,
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
|
||||
|
||||
// write back
|
||||
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||
|
Reference in New Issue
Block a user