[dnn/conv]: skeleton for NCHW layout
This commit is contained in:
@@ -44,6 +44,11 @@ public:
|
|||||||
WGRAD
|
WGRAD
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum layout_t {
|
||||||
|
NCHW,
|
||||||
|
CHWN
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// initialize and enqueue
|
// initialize and enqueue
|
||||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||||
@@ -154,6 +159,8 @@ private:
|
|||||||
// transpose
|
// transpose
|
||||||
bool AT_;
|
bool AT_;
|
||||||
bool BT_;
|
bool BT_;
|
||||||
|
// layout
|
||||||
|
layout_t layout_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -22,7 +22,8 @@ shift::shift(int B, int C,
|
|||||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||||
shift_h_(shift_h), shift_w_(shift_w),
|
shift_h_(shift_h), shift_w_(shift_w),
|
||||||
a_ty_(a_ty), b_ty_(b_ty),
|
a_ty_(a_ty), b_ty_(b_ty),
|
||||||
ty_(ty), bias_(bias) {
|
ty_(ty), bias_(bias),
|
||||||
|
layout_(CHWN){
|
||||||
// max number of channels
|
// max number of channels
|
||||||
TK_ = 16;
|
TK_ = 16;
|
||||||
MAX_C_ = 8192 + TK_;
|
MAX_C_ = 8192 + TK_;
|
||||||
@@ -31,22 +32,48 @@ shift::shift(int B, int C,
|
|||||||
CH_ = AH_ / stride_h_;
|
CH_ = AH_ / stride_h_;
|
||||||
CW_ = AW_ / stride_w_;
|
CW_ = AW_ / stride_w_;
|
||||||
// A memory strides: [C, H, W, B]
|
// A memory strides: [C, H, W, B]
|
||||||
lda_n_ = 1;
|
switch(layout_){
|
||||||
lda_w_ = B_;
|
case CHWN: {
|
||||||
lda_h_ = B_*AW_;
|
lda_n_ = 1;
|
||||||
lda_c_ = B_*AW_*AH_;
|
lda_w_ = B_;
|
||||||
|
lda_h_ = B_*AW_;
|
||||||
|
lda_c_ = B_*AW_*AH_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case NCHW: {
|
||||||
|
lda_w_ = 1;
|
||||||
|
lda_h_ = AW_;
|
||||||
|
lda_c_ = AW_*AH_;
|
||||||
|
lda_n_ = AW_*AH_*C_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("unsupported input layout");
|
||||||
|
}
|
||||||
// B memory strides: [C, F]
|
// B memory strides: [C, F]
|
||||||
ldb_n_ = 1;
|
ldb_n_ = 1;
|
||||||
ldb_h_ = 1;
|
ldb_h_ = 1;
|
||||||
ldb_w_ = 1;
|
ldb_w_ = 1;
|
||||||
ldb_c_ = F_;
|
ldb_c_ = F_;
|
||||||
// C memory strides: [F, H, W, B]
|
// C memory strides: [F, H, W, B]
|
||||||
ldc_n_ = 1;
|
switch(layout_){
|
||||||
ldc_w_ = B_;
|
case CHWN: {
|
||||||
ldc_h_ = B_*CW_;
|
ldc_n_ = 1;
|
||||||
ldc_f_ = B_*CW_*CH_;
|
ldc_w_ = B_;
|
||||||
// C shapes
|
ldc_h_ = B_*CW_;
|
||||||
shapes_c_ = {F, CH_, CW_, B};
|
ldc_f_ = B_*CW_*CH_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case NCHW: {
|
||||||
|
ldc_w_ = 1;
|
||||||
|
ldc_h_ = CW_;
|
||||||
|
ldc_f_ = CW_*CH_;
|
||||||
|
ldc_n_ = CW_*CH_*F_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("unsupported input layout");
|
||||||
|
}
|
||||||
// Equivalent matmul
|
// Equivalent matmul
|
||||||
M_ = B_*CH_*CW_;
|
M_ = B_*CH_*CW_;
|
||||||
N_ = F_;
|
N_ = F_;
|
||||||
@@ -54,8 +81,15 @@ shift::shift(int B, int C,
|
|||||||
// transpose
|
// transpose
|
||||||
AT_ = false;
|
AT_ = false;
|
||||||
BT_ = true;
|
BT_ = true;
|
||||||
|
// C shapes
|
||||||
|
if(layout_ == CHWN)
|
||||||
|
shapes_c_ = {F, CH_, CW_, B};
|
||||||
|
if(layout_ == NCHW)
|
||||||
|
shapes_c_ = {B, F, CH_, CW_};
|
||||||
// Weight gradient
|
// Weight gradient
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD){
|
||||||
|
// b <-> c
|
||||||
|
// b <-> a
|
||||||
std::swap(ldb_n_, ldc_n_);
|
std::swap(ldb_n_, ldc_n_);
|
||||||
std::swap(ldb_w_, ldc_w_);
|
std::swap(ldb_w_, ldc_w_);
|
||||||
std::swap(ldb_h_, ldc_h_);
|
std::swap(ldb_h_, ldc_h_);
|
||||||
@@ -72,6 +106,7 @@ shift::shift(int B, int C,
|
|||||||
}
|
}
|
||||||
// Input gradient
|
// Input gradient
|
||||||
if(ty_ == BPROP){
|
if(ty_ == BPROP){
|
||||||
|
// a <-> c
|
||||||
std::swap(lda_n_, ldc_n_);
|
std::swap(lda_n_, ldc_n_);
|
||||||
std::swap(lda_w_, ldc_w_);
|
std::swap(lda_w_, ldc_w_);
|
||||||
std::swap(lda_h_, ldc_h_);
|
std::swap(lda_h_, ldc_h_);
|
||||||
@@ -79,7 +114,10 @@ shift::shift(int B, int C,
|
|||||||
std::swap(K_, N_);
|
std::swap(K_, N_);
|
||||||
AT_ = false;
|
AT_ = false;
|
||||||
BT_ = false;
|
BT_ = false;
|
||||||
shapes_c_ = {C, AH_, AW_, B};
|
if(layout_ == CHWN)
|
||||||
|
shapes_c_ = {C, AH_, AW_, B};
|
||||||
|
if(layout_ == NCHW)
|
||||||
|
shapes_c_ = {B, C, AH_, AW_};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,11 +289,21 @@ if(ty_ == BPROP){
|
|||||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD && layout_ == CHWN){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||||
}
|
}
|
||||||
|
if(ty_ == WGRAD && layout_ == NCHW){
|
||||||
|
os << R"(
|
||||||
|
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||||
|
int32 rawh[TK] = rka / NB;
|
||||||
|
int32 rab[TK] = rka % NB;
|
||||||
|
int32 raw[TK] = (rawh % CW);
|
||||||
|
int32 rah[TK] = (rawh / CW);
|
||||||
|
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||||
|
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||||
|
}
|
||||||
|
|
||||||
/* B offsets */
|
/* B offsets */
|
||||||
if(ty_ == FPROP){
|
if(ty_ == FPROP){
|
||||||
@@ -301,7 +349,7 @@ if(ty_ == WGRAD){
|
|||||||
|
|
||||||
/* Increment A pointers */
|
/* Increment A pointers */
|
||||||
if(ty_ == FPROP){
|
if(ty_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pd = pd + TK;
|
pd = pd + TK;
|
||||||
d = *pd;
|
d = *pd;
|
||||||
offa_interior = d[newaxis, :];
|
offa_interior = d[newaxis, :];
|
||||||
@@ -311,14 +359,24 @@ if(ty_ == FPROP){
|
|||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(ty_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pa = pa + TK * lda_c;)";
|
pa = pa + TK * lda_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD && layout_ == CHWN){
|
||||||
os << R"(
|
os << R"(
|
||||||
pa = pa + TK;)";
|
pa = pa + TK;)";
|
||||||
}
|
}
|
||||||
|
if(ty_ == WGRAD && layout_ == NCHW){
|
||||||
os << R"(
|
os << R"(
|
||||||
@checka a = *pa;)";
|
rka = rka + TK;
|
||||||
|
rawh = rka / NB;
|
||||||
|
rab = rka % NB;
|
||||||
|
raw = (rawh % CW);
|
||||||
|
rah = (rawh / CW);
|
||||||
|
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||||
|
pa = A + offa0 + offxa[:, newaxis];)";
|
||||||
|
}
|
||||||
|
os << R"(
|
||||||
|
@checka a = *pa;)";
|
||||||
|
|
||||||
/* Increment B pointers */
|
/* Increment B pointers */
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD){
|
||||||
|
Reference in New Issue
Block a user