some cleaning
This commit is contained in:
@@ -283,11 +283,21 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
|
||||
/* A offsets */
|
||||
if(op_ == FPROP){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;
|
||||
raw = raw * stride_w;
|
||||
rah = rah * stride_h;
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
@@ -310,11 +320,21 @@ if(op_ == FPROP){
|
||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = (rawh % CW);
|
||||
int32 rah[TM] = (rawh / CW);
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||
@@ -325,12 +345,22 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rka / NB;
|
||||
int32 rab[TM] = rka % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rka / CW;
|
||||
int32 raw[TM] = rka % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = (rawh % CW);
|
||||
int32 rah[TK] = (rawh / CW);
|
||||
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||
}
|
||||
@@ -347,14 +377,24 @@ if(op_ == BPROP){
|
||||
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rbwh[TM] = rkb / NB;
|
||||
int32 rbb[TM] = rkb % NB;
|
||||
int32 rbw[TM] = rbwh % CW;
|
||||
int32 rbh[TM] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TM] = rkb / CW;
|
||||
int32 rbw[TM] = rkb % CW;
|
||||
int32 rbh[TM] = rbbh % CH;
|
||||
int32 rbb[TM] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
@@ -406,12 +446,23 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
pa = pa + TK;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
result += R"(
|
||||
rka = rka + TK;)";
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rka / NB;
|
||||
int32 rab[TM] = rka % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rka / CW;
|
||||
int32 raw[TM] = rka % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
rka = rka + TK;
|
||||
rawh = rka / NB;
|
||||
rab = rka % NB;
|
||||
raw = (rawh % CW);
|
||||
rah = (rawh / CW);
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
}
|
||||
@@ -420,12 +471,23 @@ if(op_ == WGRAD && layout_ == NCHW){
|
||||
|
||||
/* Increment B pointers */
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rkb = rkb + TK;
|
||||
rbwh = rkb / NB;
|
||||
rbb = rkb % NB;
|
||||
rbw = rbwh % CW;
|
||||
rbh = rbwh / CW;
|
||||
result += R"(
|
||||
rkb = rkb + TK;)";
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rbwh[TM] = rkb / NB;
|
||||
int32 rbb[TM] = rkb % NB;
|
||||
int32 rbw[TM] = rbwh % CW;
|
||||
int32 rbh[TM] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TM] = rkb / CW;
|
||||
int32 rbw[TM] = rkb % CW;
|
||||
int32 rbh[TM] = rbbh % CH;
|
||||
int32 rbb[TM] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
@@ -459,19 +521,41 @@ if(op_ == BPROP){
|
||||
|
||||
/* C offsets */
|
||||
if(op_ == BPROP){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = rcwh % CW;
|
||||
int32 rch[TM] = rcwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rcbh[TM] = rxc / CW;
|
||||
int32 rcw[TM] = rxc % CW;
|
||||
int32 rch[TM] = rcbh % CH;
|
||||
int32 rcb[TM] = rcbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = (rcwh % CW) * stride_w;
|
||||
int32 rch[TM] = (rcwh / CW) * stride_h;
|
||||
rcw = rcw * stride_w;
|
||||
rch = rch * stride_h;
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
if(true){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = rcwh % CW;
|
||||
int32 rch[TM] = rcwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rcbh[TM] = rxc / CW;
|
||||
int32 rcw[TM] = rxc % CW;
|
||||
int32 rch[TM] = rcbh % CH;
|
||||
int32 rcb[TM] = rcbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = (rcwh % CW);
|
||||
int32 rch[TM] = (rcwh / CW);
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
|
Reference in New Issue
Block a user