From 54617b4e516d5fc45a9c193db00ca5057a305d23 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 12 Jul 2019 20:10:15 -0700 Subject: [PATCH] some cleaning --- lib/dnn/shift.cpp | 154 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 119 insertions(+), 35 deletions(-) diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index f2502db70..72e8395ba 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -283,11 +283,21 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, /* A offsets */ if(op_ == FPROP){ + if(true){ + result += R"( + int32 rawh[TM] = rxa / NB; + int32 rab[TM] = rxa % NB; + int32 raw[TM] = rawh % CW; + int32 rah[TM] = rawh / CW;)"; + } + else{ + result += R"( + int32 rabh[TM] = rxa / CW; + int32 raw[TM] = rxa % CW; + int32 rah[TM] = rabh % CH; + int32 rab[TM] = rabh / CH;)"; + } result += R"( - int32 rawh[TM] = rxa / NB; - int32 rab[TM] = rxa % NB; - int32 raw[TM] = rawh % CW; - int32 rah[TM] = rawh / CW; raw = raw * stride_w; rah = rah * stride_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; @@ -310,11 +320,21 @@ if(op_ == FPROP){ int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; } if(op_ == BPROP){ + if(true){ + result += R"( + int32 rawh[TM] = rxa / NB; + int32 rab[TM] = rxa % NB; + int32 raw[TM] = rawh % CW; + int32 rah[TM] = rawh / CW;)"; + } + else{ + result += R"( + int32 rabh[TM] = rxa / CW; + int32 raw[TM] = rxa % CW; + int32 rah[TM] = rabh % CH; + int32 rab[TM] = rabh / CH;)"; + } result += R"( - int32 rawh[TM] = rxa / NB; - int32 rab[TM] = rxa % NB; - int32 raw[TM] = (rawh % CW); - int32 rah[TM] = (rawh / CW); int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; @@ -325,12 +345,22 @@ if(op_ == WGRAD && layout_ == CHWN){ int32 offa1[TK, TM] = rka[:, newaxis];)"; } if(op_ == WGRAD && layout_ == NCHW){ + if(true){ + result += R"( + int32 rawh[TM] = rka / NB; + int32 rab[TM] = rka % NB; + int32 raw[TM] = rawh % CW; + int32 rah[TM] = rawh / CW;)"; + } + else{ + result += R"( + int32 rabh[TM] = rka / CW; + int32 raw[TM] = rka % CW; + int32 rah[TM] = rabh % CH; + int32 rab[TM] = rabh / CH;)"; + } result += R"( int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; - int32 rawh[TK] = rka / NB; - int32 rab[TK] = rka % NB; - int32 raw[TK] = (rawh % CW); - int32 rah[TK] = (rawh / CW); int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offa1[TK, TM] = offxa[:, newaxis];)"; } @@ -347,14 +377,24 @@ if(op_ == BPROP){ int32 offb1[TK, TN] = rkb[:, newaxis];)"; } if(op_ == WGRAD){ + if(true){ + result += R"( + int32 rbwh[TM] = rkb / NB; + int32 rbb[TM] = rkb % NB; + int32 rbw[TM] = rbwh % CW; + int32 rbh[TM] = rbwh / CW;)"; + } + else{ + result += R"( + int32 rbbh[TM] = rkb / CW; + int32 rbw[TM] = rkb % CW; + int32 rbh[TM] = rbbh % CH; + int32 rbb[TM] = rbbh / CH;)"; + } result += R"( __constant__ int32* pd[TN] = delta_a + ryb; int32 d[TN] = *pd; int32 shift[TK, TN] = d[newaxis, :]; - int32 rbwh[TK] = rkb / NB; - int32 rbb[TK] = rkb % NB; - int32 rbw[TK] = rbwh % CW; - int32 rbh[TK] = rbwh / CW; rbw = rbw * stride_w; rbh = rbh * stride_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; @@ -406,12 +446,23 @@ if(op_ == WGRAD && layout_ == CHWN){ pa = pa + TK;)"; } if(op_ == WGRAD && layout_ == NCHW){ + result += R"( + rka = rka + TK;)"; + if(true){ + result += R"( + int32 rawh[TM] = rka / NB; + int32 rab[TM] = rka % NB; + int32 raw[TM] = rawh % CW; + int32 rah[TM] = rawh / CW;)"; + } + else{ + result += R"( + int32 rabh[TM] = rka / CW; + int32 raw[TM] = rka % CW; + int32 rah[TM] = rabh % CH; + int32 rab[TM] = rabh / CH;)"; + } result += R"( - rka = rka + TK; - rawh = rka / NB; - rab = rka % NB; - raw = (rawh % CW); - rah = (rawh / CW); offxa = rab*lda_b + raw*lda_w + rah*lda_h; pa = A + offa0 + offxa[:, newaxis];)"; } @@ -420,12 +471,23 @@ if(op_ == WGRAD && layout_ == NCHW){ /* Increment B pointers */ if(op_ == WGRAD){ - result += R"( - rkb = rkb + TK; - rbwh = rkb / NB; - rbb = rkb % NB; - rbw = rbwh % CW; - rbh = rbwh / CW; + result += R"( + rkb = rkb + TK;)"; + if(true){ + result += R"( + int32 rbwh[TM] = rkb / NB; + int32 rbb[TM] = rkb % NB; + int32 rbw[TM] = rbwh % CW; + int32 rbh[TM] = rbwh / CW;)"; + } + else{ + result += R"( + int32 rbbh[TM] = rkb / CW; + int32 rbw[TM] = rkb % CW; + int32 rbh[TM] = rbbh % CH; + int32 rbb[TM] = rbbh / CH;)"; + } + result += R"( rbw = rbw * stride_w; rbh = rbh * stride_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; @@ -459,19 +521,41 @@ if(op_ == BPROP){ /* C offsets */ if(op_ == BPROP){ + if(true){ + result += R"( + int32 rcwh[TM] = rxc / NB; + int32 rcb[TM] = rxc % NB; + int32 rcw[TM] = rcwh % CW; + int32 rch[TM] = rcwh / CW;)"; + } + else{ + result += R"( + int32 rcbh[TM] = rxc / CW; + int32 rcw[TM] = rxc % CW; + int32 rch[TM] = rcbh % CH; + int32 rcb[TM] = rcbh / CH;)"; + } result += R"( - int32 rcwh[TM] = rxc / NB; - int32 rcb[TM] = rxc % NB; - int32 rcw[TM] = (rcwh % CW) * stride_w; - int32 rch[TM] = (rcwh / CW) * stride_h; + rcw = rcw * stride_w; + rch = rch * stride_h; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == FPROP){ +if(true){ + result += R"( + int32 rcwh[TM] = rxc / NB; + int32 rcb[TM] = rxc % NB; + int32 rcw[TM] = rcwh % CW; + int32 rch[TM] = rcwh / CW;)"; +} +else{ + result += R"( + int32 rcbh[TM] = rxc / CW; + int32 rcw[TM] = rxc % CW; + int32 rch[TM] = rcbh % CH; + int32 rcb[TM] = rcbh / CH;)"; +} result += R"( - int32 rcwh[TM] = rxc / NB; - int32 rcb[TM] = rxc % NB; - int32 rcw[TM] = (rcwh % CW); - int32 rch[TM] = (rcwh / CW); int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == WGRAD){