[dnn/conv]: skeleton for NCHW layout
This commit is contained in:
@@ -44,6 +44,11 @@ public:
|
||||
WGRAD
|
||||
};
|
||||
|
||||
enum layout_t {
|
||||
NCHW,
|
||||
CHWN
|
||||
};
|
||||
|
||||
private:
|
||||
// initialize and enqueue
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
@@ -154,6 +159,8 @@ private:
|
||||
// transpose
|
||||
bool AT_;
|
||||
bool BT_;
|
||||
// layout
|
||||
layout_t layout_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -22,7 +22,8 @@ shift::shift(int B, int C,
|
||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias) {
|
||||
ty_(ty), bias_(bias),
|
||||
layout_(CHWN){
|
||||
// max number of channels
|
||||
TK_ = 16;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
@@ -31,22 +32,48 @@ shift::shift(int B, int C,
|
||||
CH_ = AH_ / stride_h_;
|
||||
CW_ = AW_ / stride_w_;
|
||||
// A memory strides: [C, H, W, B]
|
||||
lda_n_ = 1;
|
||||
lda_w_ = B_;
|
||||
lda_h_ = B_*AW_;
|
||||
lda_c_ = B_*AW_*AH_;
|
||||
switch(layout_){
|
||||
case CHWN: {
|
||||
lda_n_ = 1;
|
||||
lda_w_ = B_;
|
||||
lda_h_ = B_*AW_;
|
||||
lda_c_ = B_*AW_*AH_;
|
||||
break;
|
||||
}
|
||||
case NCHW: {
|
||||
lda_w_ = 1;
|
||||
lda_h_ = AW_;
|
||||
lda_c_ = AW_*AH_;
|
||||
lda_n_ = AW_*AH_*C_;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
// B memory strides: [C, F]
|
||||
ldb_n_ = 1;
|
||||
ldb_h_ = 1;
|
||||
ldb_w_ = 1;
|
||||
ldb_c_ = F_;
|
||||
// C memory strides: [F, H, W, B]
|
||||
ldc_n_ = 1;
|
||||
ldc_w_ = B_;
|
||||
ldc_h_ = B_*CW_;
|
||||
ldc_f_ = B_*CW_*CH_;
|
||||
// C shapes
|
||||
shapes_c_ = {F, CH_, CW_, B};
|
||||
switch(layout_){
|
||||
case CHWN: {
|
||||
ldc_n_ = 1;
|
||||
ldc_w_ = B_;
|
||||
ldc_h_ = B_*CW_;
|
||||
ldc_f_ = B_*CW_*CH_;
|
||||
break;
|
||||
}
|
||||
case NCHW: {
|
||||
ldc_w_ = 1;
|
||||
ldc_h_ = CW_;
|
||||
ldc_f_ = CW_*CH_;
|
||||
ldc_n_ = CW_*CH_*F_;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
// Equivalent matmul
|
||||
M_ = B_*CH_*CW_;
|
||||
N_ = F_;
|
||||
@@ -54,8 +81,15 @@ shift::shift(int B, int C,
|
||||
// transpose
|
||||
AT_ = false;
|
||||
BT_ = true;
|
||||
// C shapes
|
||||
if(layout_ == CHWN)
|
||||
shapes_c_ = {F, CH_, CW_, B};
|
||||
if(layout_ == NCHW)
|
||||
shapes_c_ = {B, F, CH_, CW_};
|
||||
// Weight gradient
|
||||
if(ty_ == WGRAD){
|
||||
// b <-> c
|
||||
// b <-> a
|
||||
std::swap(ldb_n_, ldc_n_);
|
||||
std::swap(ldb_w_, ldc_w_);
|
||||
std::swap(ldb_h_, ldc_h_);
|
||||
@@ -72,6 +106,7 @@ shift::shift(int B, int C,
|
||||
}
|
||||
// Input gradient
|
||||
if(ty_ == BPROP){
|
||||
// a <-> c
|
||||
std::swap(lda_n_, ldc_n_);
|
||||
std::swap(lda_w_, ldc_w_);
|
||||
std::swap(lda_h_, ldc_h_);
|
||||
@@ -79,7 +114,10 @@ shift::shift(int B, int C,
|
||||
std::swap(K_, N_);
|
||||
AT_ = false;
|
||||
BT_ = false;
|
||||
shapes_c_ = {C, AH_, AW_, B};
|
||||
if(layout_ == CHWN)
|
||||
shapes_c_ = {C, AH_, AW_, B};
|
||||
if(layout_ == NCHW)
|
||||
shapes_c_ = {B, C, AH_, AW_};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,11 +289,21 @@ if(ty_ == BPROP){
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
if(ty_ == WGRAD && layout_ == CHWN){
|
||||
os << R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(ty_ == WGRAD && layout_ == NCHW){
|
||||
os << R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = (rawh % CW);
|
||||
int32 rah[TK] = (rawh / CW);
|
||||
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||
}
|
||||
|
||||
/* B offsets */
|
||||
if(ty_ == FPROP){
|
||||
@@ -301,7 +349,7 @@ if(ty_ == WGRAD){
|
||||
|
||||
/* Increment A pointers */
|
||||
if(ty_ == FPROP){
|
||||
os << R"(
|
||||
os << R"(
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
offa_interior = d[newaxis, :];
|
||||
@@ -311,14 +359,24 @@ if(ty_ == FPROP){
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
pa = pa + TK * lda_c;)";
|
||||
pa = pa + TK * lda_c;)";
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
if(ty_ == WGRAD && layout_ == CHWN){
|
||||
os << R"(
|
||||
pa = pa + TK;)";
|
||||
}
|
||||
if(ty_ == WGRAD && layout_ == NCHW){
|
||||
os << R"(
|
||||
@checka a = *pa;)";
|
||||
rka = rka + TK;
|
||||
rawh = rka / NB;
|
||||
rab = rka % NB;
|
||||
raw = (rawh % CW);
|
||||
rah = (rawh / CW);
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
}
|
||||
os << R"(
|
||||
@checka a = *pa;)";
|
||||
|
||||
/* Increment B pointers */
|
||||
if(ty_ == WGRAD){
|
||||
|
Reference in New Issue
Block a user