diff --git a/python/examples/conv.py b/python/examples/conv.py index dff03488a..43f0f5d91 100644 --- a/python/examples/conv.py +++ b/python/examples/conv.py @@ -2,7 +2,7 @@ import torch import triton N, C, K = 32, 8, 32 -H, W = 4, 4 +H, W = 16, 16 R, S = 3, 3 torch.manual_seed(0) a = torch.randn(N, C, H, W).cuda() @@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda() rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) tc = triton.ops.conv(a, b) print((rc - tc).abs().max()) -print((tc[:,:,0,0] - rc[:,:,0,0]).abs()) #print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max()) #print(tc[31, 31,:,:]) \ No newline at end of file diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py index 4bf290258..8a2678f2a 100644 --- a/python/triton/ops/conv.py +++ b/python/triton/ops/conv.py @@ -21,56 +21,57 @@ void convnd(A_TYPE *A, int off_uh, int off_uw, int off_uah, int off_uaw, int off_uch, int off_ucw, - int* a_delta, int* inc_a){ + int* ADELTA, int* ADIFF){ // range of indices along the reduction axis - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; + int rxa[TM] = get_program_id(0) * TM + 0 ... TM; + int ryb[TN] = get_program_id(1) * TN + 0 ... TN; + int rk[TK] = 0 ... TK; // initialize accumulator float c[TM, TN] = 0; // pointers for A - int rxa[TM] = get_program_id(0) * TM + 0 ... TM; int rabh[TM] = rxa / CW; int raw[TM] = rxa % CW; int rab[TM] = rabh / CH; int rah[TM] = rabh % CH; rah = rah * UPAW - off_uah; raw = raw * UPAH - off_uaw; - int racr[TK] = rka / BW; - int ras[TK] = rka % BW; + int racr[TK] = rk / BW; + int ras[TK] = rk % BW; int rac[TK] = racr / BH; int rar[TK] = racr % BH; rar = UPAR * rar; ras = UPAS * ras; - int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; - int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; - A_TYPE* pa[TM, TK] = A + ra0[:, newaxis] + ra1[newaxis, :]; + int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w; + int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w; + A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :]; // pointers for B - int rbn[TN] = get_program_id(1) * TN + 0 ... TN; - B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rkb[:, newaxis] * ldb_s; + int rbk[TK] = rk; + int rbn[TN] = ryb; + B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s; // pointers for A look-up table - int offda[TK] = rka % LUT_SIZE; - int* pincd[TK] = inc_a + offda; - int* pda[TK] = a_delta + offda + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w; - int da[TK] = *pda; - int incd[TK] = *pincd; + int rklut[TK] = rk % LUT_SIZE; + int* padiff[TK] = ADIFF + rklut; + int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w; + int adiff[TK] = *padiff; + int adelta[TK] = *padelta; // reduction loop A_TYPE a[TM, TK] = *pa; B_TYPE b[TK, TN] = *pb; for(int k = K; k > 0; k = k - TK){ c += a @ b; - pa += da[newaxis, :]; + pa += adelta[newaxis, :]; pb += TK * ldb_s; // increment A look-up table - pda = pda + incd; - da = *pda; - pincd = pincd + incd; - incd = *pincd; + padelta = padelta + adiff; + adelta = *padelta; + padiff = padiff + adiff; + adiff = *padiff; // pre-fetches bool checka[TM, TK] = k > TK; bool checkb[TK, TN] = k > TK; @@ -78,7 +79,6 @@ void convnd(A_TYPE *A, b = checkb ? *pb : 0; } - // write back int rxc[TM] = get_program_id(0) * TM + 0 ... TM; int rc1[TN] = get_program_id(1) * TN + 0 ... TN;