diff --git a/python/triton/ops/conv.c b/python/triton/ops/conv.c index f2c9e899a..6a8877895 100644 --- a/python/triton/ops/conv.c +++ b/python/triton/ops/conv.c @@ -11,7 +11,8 @@ __global__ void conv(TYPE *A __noalias __readonly, // memory strides int lda_z, int lda_ci, int lda_h, int lda_w, int ldb_ci, int ldb_r, int ldb_s, int ldb_co, - int ldc_z, int ldc_co, int ldc_p, int ldc_q) { + int ldc_z, int ldc_co, int ldc_p, int ldc_q) +{ // prologue int ridx = get_program_id(0); int ridy = get_program_id(1); @@ -47,19 +48,13 @@ __global__ void conv(TYPE *A __noalias __readonly, int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :]; // pointers to lhs - int offa[TM, TK] = rz[:, newaxis] * lda_z + - rci [newaxis, :] * lda_ci + - rh * lda_h + - rw * 1; + int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci + + rh * lda_h + rw * 1; TYPE *pa[TM, TK] = A + offa; int *padelta[TK] = ADELTA + rk; // pointers to rhs - int offb[TK, TN] = rci[:, newaxis] * ldb_ci + - rr - [:, newaxis] * ldb_r + - rs - [:, newaxis] * ldb_s + - rn [newaxis, :] * 1; + int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r + + rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1; TYPE *pb[TK, TN] = B + offb; // prefetches operands @@ -72,7 +67,8 @@ __global__ void conv(TYPE *A __noalias __readonly, // reduction loop float acc[TM, TN] = 0; - for (int k = K; k > 0; k -= TK) { + for (int k = K; k > 0; k -= TK) + { acc += a @b; // increment A int adelta[TK] = *padelta; @@ -103,12 +99,8 @@ __global__ void conv(TYPE *A __noalias __readonly, rzp = rm / QQ; rp = rzp % PP; rz = rzp / PP; - int offc[TM, TN] = rz[:, newaxis] * ldc_z + - rn [newaxis, :] * ldc_co + - rp - [:, newaxis] * ldc_p + - rq - [:, newaxis] * 1; + int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co + + rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1; TYPE *pc[TM, TN] = C + offc; bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;