[PYTHON][OPS] Convolution: Some cleaning of Triton-C kernel
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
|
|
||||||
N, C, K = 32, 8, 32
|
N, C, K = 32, 8, 32
|
||||||
H, W = 4, 4
|
H, W = 16, 16
|
||||||
R, S = 3, 3
|
R, S = 3, 3
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
a = torch.randn(N, C, H, W).cuda()
|
a = torch.randn(N, C, H, W).cuda()
|
||||||
@@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda()
|
|||||||
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||||
tc = triton.ops.conv(a, b)
|
tc = triton.ops.conv(a, b)
|
||||||
print((rc - tc).abs().max())
|
print((rc - tc).abs().max())
|
||||||
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
|
|
||||||
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
|
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
|
||||||
#print(tc[31, 31,:,:])
|
#print(tc[31, 31,:,:])
|
@@ -21,56 +21,57 @@ void convnd(A_TYPE *A,
|
|||||||
int off_uh, int off_uw,
|
int off_uh, int off_uw,
|
||||||
int off_uah, int off_uaw,
|
int off_uah, int off_uaw,
|
||||||
int off_uch, int off_ucw,
|
int off_uch, int off_ucw,
|
||||||
int* a_delta, int* inc_a){
|
int* ADELTA, int* ADIFF){
|
||||||
|
|
||||||
// range of indices along the reduction axis
|
// range of indices along the reduction axis
|
||||||
int rka[TK] = 0 ... TK;
|
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||||
int rkb[TK] = 0 ... TK;
|
int ryb[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||||
|
int rk[TK] = 0 ... TK;
|
||||||
|
|
||||||
// initialize accumulator
|
// initialize accumulator
|
||||||
float c[TM, TN] = 0;
|
float c[TM, TN] = 0;
|
||||||
|
|
||||||
// pointers for A
|
// pointers for A
|
||||||
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
|
|
||||||
int rabh[TM] = rxa / CW;
|
int rabh[TM] = rxa / CW;
|
||||||
int raw[TM] = rxa % CW;
|
int raw[TM] = rxa % CW;
|
||||||
int rab[TM] = rabh / CH;
|
int rab[TM] = rabh / CH;
|
||||||
int rah[TM] = rabh % CH;
|
int rah[TM] = rabh % CH;
|
||||||
rah = rah * UPAW - off_uah;
|
rah = rah * UPAW - off_uah;
|
||||||
raw = raw * UPAH - off_uaw;
|
raw = raw * UPAH - off_uaw;
|
||||||
int racr[TK] = rka / BW;
|
int racr[TK] = rk / BW;
|
||||||
int ras[TK] = rka % BW;
|
int ras[TK] = rk % BW;
|
||||||
int rac[TK] = racr / BH;
|
int rac[TK] = racr / BH;
|
||||||
int rar[TK] = racr % BH;
|
int rar[TK] = racr % BH;
|
||||||
rar = UPAR * rar;
|
rar = UPAR * rar;
|
||||||
ras = UPAS * ras;
|
ras = UPAS * ras;
|
||||||
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||||
int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||||
A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :];
|
A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
|
||||||
|
|
||||||
// pointers for B
|
// pointers for B
|
||||||
int rbn[TN] = get_program_id(1) * TN + 0 ... TN;
|
int rbk[TK] = rk;
|
||||||
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s;
|
int rbn[TN] = ryb;
|
||||||
|
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
|
||||||
|
|
||||||
// pointers for A look-up table
|
// pointers for A look-up table
|
||||||
int offda[TK] = rka % LUT_SIZE;
|
int rklut[TK] = rk % LUT_SIZE;
|
||||||
int* pincd[TK] = inc_a + offda;
|
int* padiff[TK] = ADIFF + rklut;
|
||||||
int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
|
||||||
int da[TK] = *pda;
|
int adiff[TK] = *padiff;
|
||||||
int incd[TK] = *pincd;
|
int adelta[TK] = *padelta;
|
||||||
|
|
||||||
// reduction loop
|
// reduction loop
|
||||||
A_TYPE a[TM, TK] = *pa;
|
A_TYPE a[TM, TK] = *pa;
|
||||||
B_TYPE b[TK, TN] = *pb;
|
B_TYPE b[TK, TN] = *pb;
|
||||||
for(int k = K; k > 0; k = k - TK){
|
for(int k = K; k > 0; k = k - TK){
|
||||||
c += a @ b;
|
c += a @ b;
|
||||||
pa += da[newaxis, :];
|
pa += adelta[newaxis, :];
|
||||||
pb += TK * ldb_s;
|
pb += TK * ldb_s;
|
||||||
// increment A look-up table
|
// increment A look-up table
|
||||||
pda = pda + incd;
|
padelta = padelta + adiff;
|
||||||
da = *pda;
|
adelta = *padelta;
|
||||||
pincd = pincd + incd;
|
padiff = padiff + adiff;
|
||||||
incd = *pincd;
|
adiff = *padiff;
|
||||||
// pre-fetches
|
// pre-fetches
|
||||||
bool checka[TM, TK] = k > TK;
|
bool checka[TM, TK] = k > TK;
|
||||||
bool checkb[TK, TN] = k > TK;
|
bool checkb[TK, TN] = k > TK;
|
||||||
@@ -78,7 +79,6 @@ void convnd(A_TYPE *A,
|
|||||||
b = checkb ? *pb : 0;
|
b = checkb ? *pb : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// write back
|
// write back
|
||||||
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
|
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||||
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
|
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||||
|
Reference in New Issue
Block a user