[dnn/shift] optimizations for NCHW layout
This commit is contained in:
@@ -33,7 +33,7 @@ int main() {
|
||||
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
|
||||
shift_h.data(), shift_w.data(),
|
||||
numeric_t_str, numeric_t_str,
|
||||
op, false);
|
||||
op, false, triton::dnn::shift::NCHW);
|
||||
// host buffers
|
||||
std::vector<NumericT> ha(B*C*H*W);
|
||||
std::vector<NumericT> hb(C*F);
|
||||
|
@@ -24,7 +24,7 @@ base::base(const std::string& name)
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
|
||||
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
|
||||
bool autotune = false;
|
||||
bool autotune = true;
|
||||
driver::context* ctx = stream->context();
|
||||
triton::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
|
@@ -252,6 +252,7 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
bool is_chwn = layout_ == CHWN;
|
||||
|
||||
std::string result =
|
||||
R"(
|
||||
@@ -283,7 +284,7 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
|
||||
/* A offsets */
|
||||
if(op_ == FPROP){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
@@ -320,7 +321,7 @@ if(op_ == FPROP){
|
||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
@@ -345,19 +346,19 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rka / NB;
|
||||
int32 rab[TM] = rka % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = rawh % CW;
|
||||
int32 rah[TK] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rka / CW;
|
||||
int32 raw[TM] = rka % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
int32 rabh[TK] = rka / CW;
|
||||
int32 raw[TK] = rka % CW;
|
||||
int32 rah[TK] = rabh % CH;
|
||||
int32 rab[TK] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
@@ -377,19 +378,19 @@ if(op_ == BPROP){
|
||||
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rbwh[TM] = rkb / NB;
|
||||
int32 rbb[TM] = rkb % NB;
|
||||
int32 rbw[TM] = rbwh % CW;
|
||||
int32 rbh[TM] = rbwh / CW;)";
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TM] = rkb / CW;
|
||||
int32 rbw[TM] = rkb % CW;
|
||||
int32 rbh[TM] = rbbh % CH;
|
||||
int32 rbb[TM] = rbbh / CH;)";
|
||||
int32 rbbh[TK] = rkb / CW;
|
||||
int32 rbw[TK] = rkb % CW;
|
||||
int32 rbh[TK] = rbbh % CH;
|
||||
int32 rbb[TK] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||
@@ -448,19 +449,19 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
result += R"(
|
||||
rka = rka + TK;)";
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rka / NB;
|
||||
int32 rab[TM] = rka % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = rawh % CW;
|
||||
int32 rah[TK] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rka / CW;
|
||||
int32 raw[TM] = rka % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
int32 rabh[TK] = rka / CW;
|
||||
int32 raw[TK] = rka % CW;
|
||||
int32 rah[TK] = rabh % CH;
|
||||
int32 rab[TK] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
@@ -473,19 +474,19 @@ if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rkb = rkb + TK;)";
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rbwh[TM] = rkb / NB;
|
||||
int32 rbb[TM] = rkb % NB;
|
||||
int32 rbw[TM] = rbwh % CW;
|
||||
int32 rbh[TM] = rbwh / CW;)";
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TM] = rkb / CW;
|
||||
int32 rbw[TM] = rkb % CW;
|
||||
int32 rbh[TM] = rbbh % CH;
|
||||
int32 rbb[TM] = rbbh / CH;)";
|
||||
int32 rbbh[TK] = rkb / CW;
|
||||
int32 rbw[TK] = rkb % CW;
|
||||
int32 rbh[TK] = rbbh % CH;
|
||||
int32 rbb[TK] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
rbw = rbw * stride_w;
|
||||
@@ -521,7 +522,7 @@ if(op_ == BPROP){
|
||||
|
||||
/* C offsets */
|
||||
if(op_ == BPROP){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
@@ -541,7 +542,7 @@ if(op_ == BPROP){
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
if(true){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
|
Reference in New Issue
Block a user