[PYTHON][OPS] Convolution: Some cleaning of Triton-C kernel

This commit is contained in:
Philippe Tillet
2019-11-01 11:20:00 -04:00
parent f4bbbbe5e4
commit 50a52df489
2 changed files with 23 additions and 24 deletions

View File

@@ -2,7 +2,7 @@ import torch
import triton import triton
N, C, K = 32, 8, 32 N, C, K = 32, 8, 32
H, W = 4, 4 H, W = 16, 16
R, S = 3, 3 R, S = 3, 3
torch.manual_seed(0) torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda() a = torch.randn(N, C, H, W).cuda()
@@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda()
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b) tc = triton.ops.conv(a, b)
print((rc - tc).abs().max()) print((rc - tc).abs().max())
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max()) #print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:]) #print(tc[31, 31,:,:])

View File

@@ -21,56 +21,57 @@ void convnd(A_TYPE *A,
int off_uh, int off_uw, int off_uh, int off_uw,
int off_uah, int off_uaw, int off_uah, int off_uaw,
int off_uch, int off_ucw, int off_uch, int off_ucw,
int* a_delta, int* inc_a){ int* ADELTA, int* ADIFF){
// range of indices along the reduction axis // range of indices along the reduction axis
int rka[TK] = 0 ... TK; int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
int rkb[TK] = 0 ... TK; int ryb[TN] = get_program_id(1) * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// initialize accumulator // initialize accumulator
float c[TM, TN] = 0; float c[TM, TN] = 0;
// pointers for A // pointers for A
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
int rabh[TM] = rxa / CW; int rabh[TM] = rxa / CW;
int raw[TM] = rxa % CW; int raw[TM] = rxa % CW;
int rab[TM] = rabh / CH; int rab[TM] = rabh / CH;
int rah[TM] = rabh % CH; int rah[TM] = rabh % CH;
rah = rah * UPAW - off_uah; rah = rah * UPAW - off_uah;
raw = raw * UPAH - off_uaw; raw = raw * UPAH - off_uaw;
int racr[TK] = rka / BW; int racr[TK] = rk / BW;
int ras[TK] = rka % BW; int ras[TK] = rk % BW;
int rac[TK] = racr / BH; int rac[TK] = racr / BH;
int rar[TK] = racr % BH; int rar[TK] = racr % BH;
rar = UPAR * rar; rar = UPAR * rar;
ras = UPAS * ras; ras = UPAS * ras;
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :]; A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
// pointers for B // pointers for B
int rbn[TN] = get_program_id(1) * TN + 0 ... TN; int rbk[TK] = rk;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s; int rbn[TN] = ryb;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
// pointers for A look-up table // pointers for A look-up table
int offda[TK] = rka % LUT_SIZE; int rklut[TK] = rk % LUT_SIZE;
int* pincd[TK] = inc_a + offda; int* padiff[TK] = ADIFF + rklut;
int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w; int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int da[TK] = *pda; int adiff[TK] = *padiff;
int incd[TK] = *pincd; int adelta[TK] = *padelta;
// reduction loop // reduction loop
A_TYPE a[TM, TK] = *pa; A_TYPE a[TM, TK] = *pa;
B_TYPE b[TK, TN] = *pb; B_TYPE b[TK, TN] = *pb;
for(int k = K; k > 0; k = k - TK){ for(int k = K; k > 0; k = k - TK){
c += a @ b; c += a @ b;
pa += da[newaxis, :]; pa += adelta[newaxis, :];
pb += TK * ldb_s; pb += TK * ldb_s;
// increment A look-up table // increment A look-up table
pda = pda + incd; padelta = padelta + adiff;
da = *pda; adelta = *padelta;
pincd = pincd + incd; padiff = padiff + adiff;
incd = *pincd; adiff = *padiff;
// pre-fetches // pre-fetches
bool checka[TM, TK] = k > TK; bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK; bool checkb[TK, TN] = k > TK;
@@ -78,7 +79,6 @@ void convnd(A_TYPE *A,
b = checkb ? *pb : 0; b = checkb ? *pb : 0;
} }
// write back // write back
int rxc[TM] = get_program_id(0) * TM + 0 ... TM; int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
int rc1[TN] = get_program_id(1) * TN + 0 ... TN; int rc1[TN] = get_program_id(1) * TN + 0 ... TN;