[PYTHON][OPS] Convolution: Some cleaning of Triton-C kernel

This commit is contained in:
Philippe Tillet
2019-11-01 11:20:00 -04:00
parent f4bbbbe5e4
commit 50a52df489
2 changed files with 23 additions and 24 deletions

View File

@@ -2,7 +2,7 @@ import torch
import triton
N, C, K = 32, 8, 32
H, W = 4, 4
H, W = 16, 16
R, S = 3, 3
torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda()
@@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda()
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b)
print((rc - tc).abs().max())
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:])

View File

@@ -21,56 +21,57 @@ void convnd(A_TYPE *A,
int off_uh, int off_uw,
int off_uah, int off_uaw,
int off_uch, int off_ucw,
int* a_delta, int* inc_a){
int* ADELTA, int* ADIFF){
// range of indices along the reduction axis
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
int ryb[TN] = get_program_id(1) * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// initialize accumulator
float c[TM, TN] = 0;
// pointers for A
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
int rabh[TM] = rxa / CW;
int raw[TM] = rxa % CW;
int rab[TM] = rabh / CH;
int rah[TM] = rabh % CH;
rah = rah * UPAW - off_uah;
raw = raw * UPAH - off_uaw;
int racr[TK] = rka / BW;
int ras[TK] = rka % BW;
int racr[TK] = rk / BW;
int ras[TK] = rk % BW;
int rac[TK] = racr / BH;
int rar[TK] = racr % BH;
rar = UPAR * rar;
ras = UPAS * ras;
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :];
int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
// pointers for B
int rbn[TN] = get_program_id(1) * TN + 0 ... TN;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s;
int rbk[TK] = rk;
int rbn[TN] = ryb;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
// pointers for A look-up table
int offda[TK] = rka % LUT_SIZE;
int* pincd[TK] = inc_a + offda;
int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int da[TK] = *pda;
int incd[TK] = *pincd;
int rklut[TK] = rk % LUT_SIZE;
int* padiff[TK] = ADIFF + rklut;
int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int adiff[TK] = *padiff;
int adelta[TK] = *padelta;
// reduction loop
A_TYPE a[TM, TK] = *pa;
B_TYPE b[TK, TN] = *pb;
for(int k = K; k > 0; k = k - TK){
c += a @ b;
pa += da[newaxis, :];
pa += adelta[newaxis, :];
pb += TK * ldb_s;
// increment A look-up table
pda = pda + incd;
da = *pda;
pincd = pincd + incd;
incd = *pincd;
padelta = padelta + adiff;
adelta = *padelta;
padiff = padiff + adiff;
adiff = *padiff;
// pre-fetches
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
@@ -78,7 +79,6 @@ void convnd(A_TYPE *A,
b = checkb ? *pb : 0;
}
// write back
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;