[dnn/conv] fixed formatting of generated Triton-C code (also advances pdb by TK instead of incd in the WGRAD loop and updates the example problem size)
This commit is contained in:
@@ -12,9 +12,9 @@ int main() {
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
int32_t NC = 32, T = 1, R = 3, S = 3;
|
||||
int32_t B = 32, NF = 128;
|
||||
int32_t D = 1, H = 56, W = 56;
|
||||
int32_t NC = 128, T = 1, R = 3, S = 3;
|
||||
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
|
||||
// convolution configuration
|
||||
|
@@ -350,120 +350,119 @@ public:
|
||||
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
)";
|
||||
if(is_a_deltas_cst)
|
||||
res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
|
||||
if(is_wgrad && is_b_deltas_cst_)
|
||||
res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
|
||||
if(is_mask_cst_)
|
||||
res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
|
||||
res += R"(
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w)";
|
||||
if(!is_a_deltas_cst)
|
||||
res += ", int32* delta\n";
|
||||
if(is_wgrad && !is_b_deltas_cst_)
|
||||
res += ", int32* b_delta\n";
|
||||
if(!is_mask_cst_)
|
||||
res += ", int32* masks\n";
|
||||
res += R"(){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 ldlut = )" + std::to_string(Fs_) + R"(;
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW - pad_w;
|
||||
int32 rab[TM] = rabh / CH;
|
||||
int32 rah[TM] = rabh % CH - pad_h;
|
||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
||||
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
||||
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
|
||||
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
|
||||
rar = )" + flipr + R"( rar;
|
||||
ras = )" + flips + R"( ras;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
int32 rbcr[TK] = rkb / BW;
|
||||
int32 rbs[TK] = rkb % BW;
|
||||
int32 rbc[TK] = rbcr / BH;
|
||||
int32 rbr[TK] = rbcr % BH;
|
||||
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
|
||||
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
|
||||
int32 db[TK] = *pdb;)";
|
||||
}
|
||||
else{
|
||||
res += R"(
|
||||
int32 rb1[TK] = rkb;)";
|
||||
}
|
||||
res += R"(
|
||||
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
|
||||
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
|
||||
)" + a_delta_mem + R"( int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
fp32 a[TM, TK] = checka ? *pa : 0;
|
||||
fp32 b)" + BS + R"( = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + d[newaxis, :];
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
pdb = pdb + incd;
|
||||
db = *pdb;)";
|
||||
}
|
||||
res += R"(
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
incm = *pincm;
|
||||
checka0 = *pm;
|
||||
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
checka = checka && (k > TK);
|
||||
a = checka ? *pa : 0;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (CH*CW);
|
||||
int32 rcpq[TM] = rxc % (CH*CW);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = rc1 < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
)";
|
||||
if(is_a_deltas_cst)
|
||||
res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
|
||||
if(is_wgrad && is_b_deltas_cst_)
|
||||
res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
|
||||
if(is_mask_cst_)
|
||||
res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
|
||||
res += R"(
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w)";
|
||||
if(!is_a_deltas_cst)
|
||||
res += ", int32* delta";
|
||||
if(is_wgrad && !is_b_deltas_cst_)
|
||||
res += ", int32* b_delta";
|
||||
if(!is_mask_cst_)
|
||||
res += ", int32* masks";
|
||||
res += R"(){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 ldlut = )" + std::to_string(Fs_) + R"(;
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW - pad_w;
|
||||
int32 rab[TM] = rabh / CH;
|
||||
int32 rah[TM] = rabh % CH - pad_h;
|
||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
||||
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
||||
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
|
||||
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
|
||||
rar = )" + flipr + R"( rar;
|
||||
ras = )" + flips + R"( ras;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
int32 rbcr[TK] = rkb / BW;
|
||||
int32 rbs[TK] = rkb % BW;
|
||||
int32 rbc[TK] = rbcr / BH;
|
||||
int32 rbr[TK] = rbcr % BH;
|
||||
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
|
||||
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
|
||||
int32 db[TK] = *pdb;)";
|
||||
}
|
||||
else{
|
||||
res += R"(
|
||||
int32 rb1[TK] = rkb;)";
|
||||
}
|
||||
res += R"(
|
||||
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
|
||||
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
|
||||
)" + a_delta_mem + R"( int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
fp32 a[TM, TK] = checka ? *pa : 0;
|
||||
fp32 b)" + BS + R"( = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + d[newaxis, :];
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
pdb = pdb + TK;
|
||||
db = *pdb;)";
|
||||
}
|
||||
res += R"(
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
incm = *pincm;
|
||||
checka0 = *pm;
|
||||
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
checka = checka && (k > TK);
|
||||
a = checka ? *pa : 0;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (CH*CW);
|
||||
int32 rcpq[TM] = rxc % (CH*CW);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = rc1 < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
return res;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user