[dnn/conv]: skeleton for NCHW layout

This commit is contained in:
Philippe Tillet
2019-07-11 20:34:38 -07:00
parent 207e021973
commit fe8caf12f0
2 changed files with 83 additions and 18 deletions

View File

@@ -44,6 +44,11 @@ public:
WGRAD
};
// Memory layout of the activation tensors handled by this operation.
// The letters appear to name the dimensions from outermost (largest stride)
// to innermost (unit stride): N = batch, C = channels, H = height, W = width.
// NOTE(review): inferred from the stride setup in the constructor — confirm.
enum layout_t {
// batch-major layout: width has unit stride
NCHW,
// channel-major layout: batch has unit stride
CHWN
};
private:
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
@@ -154,6 +159,8 @@ private:
// transpose
bool AT_;
bool BT_;
// layout
layout_t layout_;
};
}

View File

@@ -22,7 +22,8 @@ shift::shift(int B, int C,
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias) {
ty_(ty), bias_(bias),
layout_(CHWN){
// max number of channels
TK_ = 16;
MAX_C_ = 8192 + TK_;
@@ -31,22 +32,48 @@ shift::shift(int B, int C,
CH_ = AH_ / stride_h_;
CW_ = AW_ / stride_w_;
// A memory strides (layout-dependent): [C, H, W, B] for CHWN, [B, C, H, W] for NCHW
lda_n_ = 1;
lda_w_ = B_;
lda_h_ = B_*AW_;
lda_c_ = B_*AW_*AH_;
switch(layout_){
case CHWN: {
lda_n_ = 1;
lda_w_ = B_;
lda_h_ = B_*AW_;
lda_c_ = B_*AW_*AH_;
break;
}
case NCHW: {
lda_w_ = 1;
lda_h_ = AW_;
lda_c_ = AW_*AH_;
lda_n_ = AW_*AH_*C_;
break;
}
default:
throw std::runtime_error("unsupported input layout");
}
// B memory strides: [C, F]
ldb_n_ = 1;
ldb_h_ = 1;
ldb_w_ = 1;
ldb_c_ = F_;
// C memory strides (layout-dependent): [F, H, W, B] for CHWN, [B, F, H, W] for NCHW
ldc_n_ = 1;
ldc_w_ = B_;
ldc_h_ = B_*CW_;
ldc_f_ = B_*CW_*CH_;
// C shapes
shapes_c_ = {F, CH_, CW_, B};
switch(layout_){
case CHWN: {
ldc_n_ = 1;
ldc_w_ = B_;
ldc_h_ = B_*CW_;
ldc_f_ = B_*CW_*CH_;
break;
}
case NCHW: {
ldc_w_ = 1;
ldc_h_ = CW_;
ldc_f_ = CW_*CH_;
ldc_n_ = CW_*CH_*F_;
break;
}
default:
throw std::runtime_error("unsupported input layout");
}
// Equivalent matmul
M_ = B_*CH_*CW_;
N_ = F_;
@@ -54,8 +81,15 @@ shift::shift(int B, int C,
// transpose
AT_ = false;
BT_ = true;
// C shapes
if(layout_ == CHWN)
shapes_c_ = {F, CH_, CW_, B};
if(layout_ == NCHW)
shapes_c_ = {B, F, CH_, CW_};
// Weight gradient
if(ty_ == WGRAD){
// b <-> c
// b <-> a
std::swap(ldb_n_, ldc_n_);
std::swap(ldb_w_, ldc_w_);
std::swap(ldb_h_, ldc_h_);
@@ -72,6 +106,7 @@ shift::shift(int B, int C,
}
// Input gradient
if(ty_ == BPROP){
// a <-> c
std::swap(lda_n_, ldc_n_);
std::swap(lda_w_, ldc_w_);
std::swap(lda_h_, ldc_h_);
@@ -79,7 +114,10 @@ shift::shift(int B, int C,
std::swap(K_, N_);
AT_ = false;
BT_ = false;
shapes_c_ = {C, AH_, AW_, B};
if(layout_ == CHWN)
shapes_c_ = {C, AH_, AW_, B};
if(layout_ == NCHW)
shapes_c_ = {B, C, AH_, AW_};
}
}
@@ -251,11 +289,21 @@ if(ty_ == BPROP){
int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
}
if(ty_ == WGRAD){
if(ty_ == WGRAD && layout_ == CHWN){
os << R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offa1[TK, TM] = rka[:, newaxis];)";
}
if(ty_ == WGRAD && layout_ == NCHW){
os << R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = (rawh % CW);
int32 rah[TK] = (rawh / CW);
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)";
}
/* B offsets */
if(ty_ == FPROP){
@@ -301,7 +349,7 @@ if(ty_ == WGRAD){
/* Increment A pointers */
if(ty_ == FPROP){
os << R"(
os << R"(
pd = pd + TK;
d = *pd;
offa_interior = d[newaxis, :];
@@ -311,14 +359,24 @@ if(ty_ == FPROP){
}
if(ty_ == BPROP){
os << R"(
pa = pa + TK * lda_c;)";
pa = pa + TK * lda_c;)";
}
if(ty_ == WGRAD){
os << R"(
if(ty_ == WGRAD && layout_ == CHWN){
os << R"(
pa = pa + TK;)";
}
if(ty_ == WGRAD && layout_ == NCHW){
os << R"(
@checka a = *pa;)";
rka = rka + TK;
rawh = rka / NB;
rab = rka % NB;
raw = (rawh % CW);
rah = (rawh / CW);
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)";
}
os << R"(
@checka a = *pa;)";
/* Increment B pointers */
if(ty_ == WGRAD){