some cleaning

This commit is contained in:
Philippe Tillet
2019-07-12 20:10:15 -07:00
parent 7512c7ebed
commit 54617b4e51

View File

@@ -283,11 +283,21 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
/* A offsets */ /* A offsets */
if(op_ == FPROP){ if(op_ == FPROP){
if(true){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"( result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;
raw = raw * stride_w; raw = raw * stride_w;
rah = rah * stride_h; rah = rah * stride_h;
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
@@ -310,11 +320,21 @@ if(op_ == FPROP){
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
if(true){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"( result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = (rawh % CW);
int32 rah[TM] = (rawh / CW);
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
@@ -325,12 +345,22 @@ if(op_ == WGRAD && layout_ == CHWN){
int32 offa1[TK, TM] = rka[:, newaxis];)"; int32 offa1[TK, TM] = rka[:, newaxis];)";
} }
if(op_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
if(true){
result += R"(
int32 rawh[TM] = rka / NB;
int32 rab[TM] = rka % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rka / CW;
int32 raw[TM] = rka % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"( result += R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = (rawh % CW);
int32 rah[TK] = (rawh / CW);
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)"; int32 offa1[TK, TM] = offxa[:, newaxis];)";
} }
@@ -347,14 +377,24 @@ if(op_ == BPROP){
int32 offb1[TK, TN] = rkb[:, newaxis];)"; int32 offb1[TK, TN] = rkb[:, newaxis];)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
if(true){
result += R"(
int32 rbwh[TM] = rkb / NB;
int32 rbb[TM] = rkb % NB;
int32 rbw[TM] = rbwh % CW;
int32 rbh[TM] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TM] = rkb / CW;
int32 rbw[TM] = rkb % CW;
int32 rbh[TM] = rbbh % CH;
int32 rbb[TM] = rbbh / CH;)";
}
result += R"( result += R"(
__constant__ int32* pd[TN] = delta_a + ryb; __constant__ int32* pd[TN] = delta_a + ryb;
int32 d[TN] = *pd; int32 d[TN] = *pd;
int32 shift[TK, TN] = d[newaxis, :]; int32 shift[TK, TN] = d[newaxis, :];
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
@@ -406,12 +446,23 @@ if(op_ == WGRAD && layout_ == CHWN){
pa = pa + TK;)"; pa = pa + TK;)";
} }
if(op_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
result += R"(
rka = rka + TK;)";
if(true){
result += R"(
int32 rawh[TM] = rka / NB;
int32 rab[TM] = rka % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rka / CW;
int32 raw[TM] = rka % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"( result += R"(
rka = rka + TK;
rawh = rka / NB;
rab = rka % NB;
raw = (rawh % CW);
rah = (rawh / CW);
offxa = rab*lda_b + raw*lda_w + rah*lda_h; offxa = rab*lda_b + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)"; pa = A + offa0 + offxa[:, newaxis];)";
} }
@@ -420,12 +471,23 @@ if(op_ == WGRAD && layout_ == NCHW){
/* Increment B pointers */ /* Increment B pointers */
if(op_ == WGRAD){ if(op_ == WGRAD){
result += R"( result += R"(
rkb = rkb + TK; rkb = rkb + TK;)";
rbwh = rkb / NB; if(true){
rbb = rkb % NB; result += R"(
rbw = rbwh % CW; int32 rbwh[TM] = rkb / NB;
rbh = rbwh / CW; int32 rbb[TM] = rkb % NB;
int32 rbw[TM] = rbwh % CW;
int32 rbh[TM] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TM] = rkb / CW;
int32 rbw[TM] = rkb % CW;
int32 rbh[TM] = rbbh % CH;
int32 rbb[TM] = rbbh / CH;)";
}
result += R"(
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
@@ -459,19 +521,41 @@ if(op_ == BPROP){
/* C offsets */ /* C offsets */
if(op_ == BPROP){ if(op_ == BPROP){
if(true){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"( result += R"(
int32 rcwh[TM] = rxc / NB; rcw = rcw * stride_w;
int32 rcb[TM] = rxc % NB; rch = rch * stride_h;
int32 rcw[TM] = (rcwh % CW) * stride_w;
int32 rch[TM] = (rcwh / CW) * stride_h;
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
if(true){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"( result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = (rcwh % CW);
int32 rch[TM] = (rcwh / CW);
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){