[dnn/conv] optimizations of backpropagation with look-up tables
This commit is contained in:
@@ -10,7 +10,7 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
@@ -66,7 +66,7 @@ int main() {
|
||||
return configuration.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
std::string src = configuration.src();
|
||||
// jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
|
@@ -74,6 +74,8 @@ public:
|
||||
}
|
||||
// look-up table info
|
||||
Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
if(ty_ == BPROP)
|
||||
Fs_ *= shapes_b_[4];
|
||||
TK_ = 8;
|
||||
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
|
||||
}
|
||||
@@ -101,15 +103,24 @@ public:
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
auto unpack = [&](int32_t trs){
|
||||
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_;
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_;
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs - tr*BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr - t*BH_;
|
||||
return std::make_tuple(t, r, s);
|
||||
int32_t r = tr % BH_;
|
||||
if(ty_ == BPROP){
|
||||
r = BH_ - 1 - r;
|
||||
s = BW_ - 1 - s;
|
||||
}
|
||||
return std::make_tuple(l, t, r, s);
|
||||
};
|
||||
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
deltas[i] = (((i + TK_) % Luts_) - i);
|
||||
|
||||
size_t Ds0 = Luts_;
|
||||
size_t Ds1 = upsample_w_;
|
||||
size_t Ds2 = upsample_h_;
|
||||
@@ -119,17 +130,15 @@ public:
|
||||
for(size_t pw = 0; pw < Ds1; ++pw){
|
||||
int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i){
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
// unpack
|
||||
int32_t ctrs = i;
|
||||
int32_t c = ctrs / Fs_;
|
||||
int32_t t, r, s;
|
||||
std::tie(t, r, s) = unpack(ctrs % Fs_);
|
||||
|
||||
int32_t c, t, r, s;
|
||||
std::tie(c, t, r, s) = unpack(ctrs);
|
||||
// next indices
|
||||
int32_t nextctrs = ctrs + TK_;
|
||||
int32_t nextc = nextctrs / Fs_;
|
||||
int32_t nextt, nextr, nexts;
|
||||
std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
|
||||
int32_t nextc, nextt, nextr, nexts;
|
||||
std::tie(nextc, nextt, nextr, nexts) = unpack(nextctrs);
|
||||
// diffs
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
|
||||
@@ -145,12 +154,18 @@ public:
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
auto unpack = [&](int32_t trs){
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_;
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_;
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs - tr*BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr - t*BH_;
|
||||
return std::make_tuple(t, r, s);
|
||||
int32_t r = tr % BH_;
|
||||
if(ty_ == BPROP){
|
||||
r = BH_ - 1 - r;
|
||||
s = BW_ - 1 - s;
|
||||
}
|
||||
return std::make_tuple(l, t, r, s);
|
||||
};
|
||||
size_t Ms0 = Luts_;
|
||||
size_t Ms1 = 2*pad_w_ + 1;
|
||||
@@ -161,10 +176,10 @@ public:
|
||||
for(size_t pw = 0; pw < Ms1; ++pw){
|
||||
int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
for(size_t i = 0; i < Ms0; ++i){
|
||||
int32_t t, r, s;
|
||||
int32_t l, t, r, s;
|
||||
int32_t mask = 0x0;
|
||||
for(size_t j = 0; j < TK_; ++j){
|
||||
std::tie(t, r, s) = unpack((i + j) % Fs_);
|
||||
std::tie(l, t, r, s) = unpack(i + j);
|
||||
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (BD_ + pad_d_);
|
||||
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (BH_ + pad_h_);
|
||||
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (BW_ + pad_w_);
|
||||
@@ -220,50 +235,29 @@ public:
|
||||
}
|
||||
|
||||
// Returns the hard-coded default parameter list for the current convolution
// type (ty_). Elsewhere in this view the result is passed as the third
// argument of jit.add_module("conv", ...).
//   FPROP     -> 14 values
//   BPROP     -> 14 values (a different set from FPROP)
//   otherwise -> 10 values
// NOTE(review): the meaning of each position in the lists is not visible
// here; presumably they line up with the kernel's tunable parameters — confirm.
std::vector<unsigned> default_params() {
|
||||
// NOTE(review): this text is a rendered diff. The combined FPROP/BPROP test
// on the next line appears to be the pre-change condition, and the separate
// FPROP / BPROP branches after it appear to be its replacement — verify
// against the real file before relying on the control flow as written here.
if(ty_ == FPROP || ty_ == BPROP)
|
||||
if(ty_ == FPROP)
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
|
||||
else if(ty_ == BPROP)
|
||||
return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2};
|
||||
else
|
||||
// Any remaining type falls through to the shorter list.
return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8};
|
||||
}
|
||||
|
||||
|
||||
std::string xprop() {
|
||||
|
||||
std::string declare_pb;
|
||||
if(ty_ == FPROP){
|
||||
declare_pb = R"(
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*ldb_s + rb0[:, newaxis];
|
||||
)";
|
||||
}
|
||||
else{
|
||||
declare_pb = R"(
|
||||
fp32* pb_base[TN, TK] = b + rb0[:, newaxis]*ldb_c;
|
||||
int32 rbk[TK] = rkb / (BH*BW);
|
||||
int32 rbrs[TK] = rkb % (BH*BW);
|
||||
int32 rbs[TK] = BW - 1 - rbrs % BW;
|
||||
int32 rbr[TK] = BH - 1 - rbrs / BW;
|
||||
int32 rb1[TK] = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
|
||||
fp32* pb[TN, TK] = pb_base + rb1[newaxis, :];
|
||||
)";
|
||||
}
|
||||
std::string increment_pb;
|
||||
if(ty_ == FPROP){
|
||||
increment_pb = R"(
|
||||
pb = pb + TK*ldb_s;
|
||||
)";
|
||||
}
|
||||
else{
|
||||
increment_pb = R"(
|
||||
rbrs = rbrs + TK;
|
||||
rkb = rkb + TK;
|
||||
rbk = rkb / (BH*BW);
|
||||
rbrs = rkb % (BH*BW);
|
||||
rbs = BW - 1 - rbrs % BW;
|
||||
rbr = BH - 1 - rbrs / BW;
|
||||
rb1 = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
|
||||
pb = pb_base + rb1[newaxis, :];
|
||||
)";
|
||||
}
|
||||
bool trans_b = ty_ == FPROP;
|
||||
std::string BS = trans_b ?"[TN,TK]" : "[TK, TN]";
|
||||
std::string bcb0 = trans_b ?"[:, newaxis]" : "[newaxis, :]";
|
||||
std::string bcb1 = trans_b ?"[newaxis, :]" : "[:, newaxis]";
|
||||
std::string ldb0 = trans_b ?"*ldb_s" : "";
|
||||
std::string ldb1 = trans_b ?"" : "*ldb_c";
|
||||
std::string useb = trans_b ?"trans(b)" : "b";
|
||||
std::string flipr = trans_b?"" : "BH - 1 -";
|
||||
std::string flips = trans_b?"" : "BW - 1 -";
|
||||
std::string ax = trans_b?"crs" : "rsc";
|
||||
std::vector<std::string> redax = {"BH", "BW", "N"};
|
||||
if(trans_b)
|
||||
redax = {"C", "BH", "BW"};
|
||||
|
||||
std::string res =
|
||||
R"(
|
||||
@@ -271,8 +265,8 @@ public:
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[18];
|
||||
__constant__ int32* masks = alloc_const int32[1024];
|
||||
__constant__ int32* delta = alloc_const int32[1024];
|
||||
__constant__ int32* masks = alloc_const int32[4096];
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
@@ -290,36 +284,39 @@ public:
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 Fs = )" + std::to_string(Fs_) + R"(;
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW - pad_w;
|
||||
int32 rab[TM] = rabh / CH;
|
||||
int32 rah[TM] = rabh % CH - pad_h;
|
||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 racr[TK] = rka / BW;
|
||||
int32 ras[TK] = rka % BW;
|
||||
int32 rac[TK] = racr / BH;
|
||||
int32 rar[TK] = racr % BH;
|
||||
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
||||
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
||||
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
|
||||
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
|
||||
rar = )" + flipr + R"( rar;
|
||||
ras = )" + flips + R"( ras;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)"
|
||||
+ declare_pb + R"(
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
|
||||
fp32* pb)" + BS + " = b + rkb" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
__constant__ int32* pincd[TK] = delta + rka;
|
||||
__constant__ int32* pd[TK] = delta + BH*BW + rka;
|
||||
__constant__ int32* pd[TK] = delta + Fs + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
__constant__ int32* pm[TM] = masks + BH*BW + maskw*BH*BW + maskh*BH*BW*(2*pad_w + 1);
|
||||
__constant__ int32* pm[TM] = masks + Fs + maskw*Fs + maskh*Fs*(2*pad_w + 1);
|
||||
__constant__ int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
fp32 a[TM, TK] = checka ? *pa : 0;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
fp32 b)" + BS + R"( = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
pa = pa + d[newaxis, :];)"
|
||||
+ increment_pb + R"(
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + d[newaxis, :];
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;
|
||||
pincd = pincd + incd;
|
||||
@@ -448,9 +445,9 @@ public:
|
||||
if(ty_==FPROP)
|
||||
b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]];
|
||||
else{
|
||||
int32_t bdd = bd;
|
||||
int32_t bhh = bh;
|
||||
int32_t bww = bw;
|
||||
int32_t bdd = shapes_b_[1] - 1 - bd;
|
||||
int32_t bhh = shapes_b_[2] - 1 - bh;
|
||||
int32_t bww = shapes_b_[3] - 1 - bw;
|
||||
b = B[cf*ld_b_[0] + bdd*ld_b_[1] + bhh*ld_b_[2] + bww*ld_b_[3] + ac*ld_b_[4]];
|
||||
}
|
||||
acc = std::fma(a, b, acc);
|
||||
|
Reference in New Issue
Block a user