[dnn/shift] optimizations for NCHW layout

This commit is contained in:
Philippe Tillet
2019-07-12 20:22:32 -07:00
parent 54617b4e51
commit fe42cb7142
3 changed files with 43 additions and 42 deletions

View File

@@ -33,7 +33,7 @@ int main() {
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t_str, numeric_t_str,
op, false);
op, false, triton::dnn::shift::NCHW);
// host buffers
std::vector<NumericT> ha(B*C*H*W);
std::vector<NumericT> hb(C*F);

View File

@@ -24,7 +24,7 @@ base::base(const std::string& name)
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
bool autotune = false;
bool autotune = true;
driver::context* ctx = stream->context();
triton::jit* jit;
/* the current template has not already been compiled */

View File

@@ -252,6 +252,7 @@ void shift::triton_c_src(std::ostream &os) const {
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
bool is_chwn = layout_ == CHWN;
std::string result =
R"(
@@ -283,7 +284,7 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
/* A offsets */
if(op_ == FPROP){
if(true){
if(is_chwn){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
@@ -320,7 +321,7 @@ if(op_ == FPROP){
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
}
if(op_ == BPROP){
if(true){
if(is_chwn){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
@@ -345,19 +346,19 @@ if(op_ == WGRAD && layout_ == CHWN){
int32 offa1[TK, TM] = rka[:, newaxis];)";
}
if(op_ == WGRAD && layout_ == NCHW){
if(true){
if(is_chwn){
result += R"(
int32 rawh[TM] = rka / NB;
int32 rab[TM] = rka % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rka / CW;
int32 raw[TM] = rka % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
@@ -377,19 +378,19 @@ if(op_ == BPROP){
int32 offb1[TK, TN] = rkb[:, newaxis];)";
}
if(op_ == WGRAD){
if(true){
if(is_chwn){
result += R"(
int32 rbwh[TM] = rkb / NB;
int32 rbb[TM] = rkb % NB;
int32 rbw[TM] = rbwh % CW;
int32 rbh[TM] = rbwh / CW;)";
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TM] = rkb / CW;
int32 rbw[TM] = rkb % CW;
int32 rbh[TM] = rbbh % CH;
int32 rbb[TM] = rbbh / CH;)";
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
__constant__ int32* pd[TN] = delta_a + ryb;
@@ -448,19 +449,19 @@ if(op_ == WGRAD && layout_ == CHWN){
if(op_ == WGRAD && layout_ == NCHW){
result += R"(
rka = rka + TK;)";
if(true){
if(is_chwn){
result += R"(
int32 rawh[TM] = rka / NB;
int32 rab[TM] = rka % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rka / CW;
int32 raw[TM] = rka % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
@@ -473,19 +474,19 @@ if(op_ == WGRAD && layout_ == NCHW){
if(op_ == WGRAD){
result += R"(
rkb = rkb + TK;)";
if(true){
if(is_chwn){
result += R"(
int32 rbwh[TM] = rkb / NB;
int32 rbb[TM] = rkb % NB;
int32 rbw[TM] = rbwh % CW;
int32 rbh[TM] = rbwh / CW;)";
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TM] = rkb / CW;
int32 rbw[TM] = rkb % CW;
int32 rbh[TM] = rbbh % CH;
int32 rbb[TM] = rbbh / CH;)";
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
rbw = rbw * stride_w;
@@ -521,7 +522,7 @@ if(op_ == BPROP){
/* C offsets */
if(op_ == BPROP){
if(true){
if(is_chwn){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
@@ -541,7 +542,7 @@ if(op_ == BPROP){
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
}
if(op_ == FPROP){
if(true){
if(is_chwn){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;