[dnn/conv] removed divergent paths in LUT computations
This commit is contained in:
@@ -46,7 +46,6 @@ public:
|
||||
|
||||
// source
|
||||
std::string src(){
|
||||
bool is_wgrad = ty_ == WGRAD;
|
||||
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
|
||||
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
|
||||
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
|
||||
@@ -61,7 +60,7 @@ public:
|
||||
redax = {"C", "BH", "BW"};
|
||||
else
|
||||
redax = {"BH", "BW", "N"};
|
||||
std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
|
||||
std::string inc_pb = b_lut_ ? "db[newaxis, :]" : "TK" + ldb0;
|
||||
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
|
||||
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
|
||||
std::string masks_mem = is_mask_cst_? "__constant__" : "";
|
||||
@@ -74,7 +73,7 @@ public:
|
||||
)";
|
||||
if(is_a_deltas_cst)
|
||||
res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
|
||||
if(is_wgrad && is_b_deltas_cst_)
|
||||
if(b_lut_ && is_b_deltas_cst_)
|
||||
res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
|
||||
if(is_mask_cst_)
|
||||
res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
|
||||
@@ -96,7 +95,7 @@ public:
|
||||
int32 upsample_h, int32 upsample_w)";
|
||||
if(!is_a_deltas_cst)
|
||||
res += ", int32* delta";
|
||||
if(is_wgrad && !is_b_deltas_cst_)
|
||||
if(b_lut_ && !is_b_deltas_cst_)
|
||||
res += ", int32* b_delta";
|
||||
if(!is_mask_cst_)
|
||||
res += ", int32* masks";
|
||||
@@ -122,7 +121,7 @@ public:
|
||||
ras = )" + flips + R"( ras;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
|
||||
if(ty_ == WGRAD){
|
||||
if(b_lut_){
|
||||
res += R"(
|
||||
int32 rbcr[TK] = rkb / BW;
|
||||
int32 rbs[TK] = rkb % BW;
|
||||
@@ -158,7 +157,7 @@ public:
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;)";
|
||||
if(ty_ == WGRAD){
|
||||
if(b_lut_){
|
||||
res += R"(
|
||||
pdb = pdb + TK;
|
||||
db = *pdb;)";
|
||||
|
@@ -73,12 +73,12 @@ conv::conv(int B, int NC,
|
||||
// equivalent matmul
|
||||
b_trans_ = ty_ != BPROP;
|
||||
b_lut_ = ty_ == WGRAD;
|
||||
if(ty_ == WGRAD){
|
||||
if(ty_ == WGRAD) {
|
||||
M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3];
|
||||
N_ = shapes_c_[4];
|
||||
K_ = shapes_b_[0]*shapes_b_[2]*shapes_b_[3]*shapes_b_[4];
|
||||
}
|
||||
else{
|
||||
else {
|
||||
M_ = shapes_c_[0]*shapes_c_[2]*shapes_c_[3]*shapes_c_[4];
|
||||
N_ = shapes_c_[1];
|
||||
K_ = shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
@@ -120,14 +120,14 @@ void conv::build_deltas(){
|
||||
if(b_lut_)
|
||||
h_b_deltas_.resize(Luts_);
|
||||
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
auto unpack = [&](int32_t ltrs) {
|
||||
int32_t l = (!b_trans_) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (!b_trans_) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr % BH_;
|
||||
if(ty_ == BPROP){
|
||||
if(!b_trans_){
|
||||
r = BH_ - 1 - r;
|
||||
s = BW_ - 1 - s;
|
||||
}
|
||||
@@ -143,7 +143,7 @@ void conv::build_deltas(){
|
||||
size_t Ds3 = upsample_d_;
|
||||
for(size_t pd = 0; pd < Ds3; ++pd)
|
||||
for(size_t ph = 0; ph < Ds2; ++ph)
|
||||
for(size_t pw = 0; pw < Ds1; ++pw){
|
||||
for(size_t pw = 0; pw < Ds1; ++pw) {
|
||||
int32_t* deltas_ptr = &h_a_deltas_[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
@@ -161,21 +161,21 @@ void conv::build_deltas(){
|
||||
int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
|
||||
// delta pointers
|
||||
if(ty_ == WGRAD)
|
||||
deltas_ptr[i] = cdiff*ld_a_[0] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
else
|
||||
deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
deltas_ptr[i] = cdiff*ld_a_[a_inner_idx_] + tdiff*ld_a_[a_pix_idx_] + rdiff*ld_a_[a_pix_idx_ + 1] + sdiff*ld_a_[a_pix_idx_ + 2];
|
||||
}
|
||||
}
|
||||
|
||||
if(ty_ == WGRAD){
|
||||
if(b_lut_) {
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
int32_t c, t, r, s;
|
||||
int32_t nextc, nextt, nextr, nexts;
|
||||
std::tie(c, t, r, s) = unpack(i);
|
||||
std::tie(nextc, nextt, nextr, nexts) = unpack(i + TK_);
|
||||
int32_t cdiff = nextc - c, tdiff = nextt - t, rdiff = nextr - r, sdiff = nexts - s;
|
||||
h_b_deltas_[i] = cdiff*ld_b_[0] + tdiff*ld_b_[2] + rdiff*ld_b_[3] + sdiff*ld_b_[4];
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = nextt - t;
|
||||
int32_t rdiff = nextr - r;
|
||||
int32_t sdiff = nexts - s;
|
||||
h_b_deltas_[i] = cdiff*ld_b_[b_inner_idx_] + tdiff*ld_b_[b_pix_idx_] + rdiff*ld_b_[b_pix_idx_ + 1] + sdiff*ld_b_[b_pix_idx_ + 2];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,13 +184,13 @@ void conv::build_masks(){
|
||||
h_masks_.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
int32_t l = (!b_trans_) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (!b_trans_) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr % BH_;
|
||||
if(ty_ == BPROP){
|
||||
if(!b_trans_){
|
||||
r = BH_ - 1 - r;
|
||||
s = BW_ - 1 - s;
|
||||
}
|
||||
@@ -338,7 +338,7 @@ void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
if(in_bounds)
|
||||
a = A[n*ld_a_[0] + ac*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
|
||||
IN_DTYPE b;
|
||||
if(ty_==FPROP)
|
||||
if(b_trans_)
|
||||
b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]];
|
||||
else{
|
||||
int32_t bdd = shapes_b_[1] - 1 - bd;
|
||||
|
Reference in New Issue
Block a user