[dnn/conv] Added the option to have look-up table for filters for all

operations
This commit is contained in:
Philippe Tillet
2019-05-22 19:03:33 -04:00
parent f8291af7ef
commit 3f3eb1c2a4
3 changed files with 13 additions and 13 deletions

View File

@@ -14,7 +14,7 @@ int main() {
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 56, W = 56;
int32_t NC = 32, T = 1, R = 3, S = 3;
int32_t NC = 16, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;

View File

@@ -60,7 +60,7 @@ public:
redax = {"C", "BH", "BW"};
else
redax = {"BH", "BW", "N"};
std::string inc_pb = b_lut_ ? "db[newaxis, :]" : "TK" + ldb0;
std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0;
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
std::string masks_mem = is_mask_cst_? "__constant__" : "";
@@ -133,13 +133,13 @@ public:
}
else{
res += R"(
int32 rb1[TK] = rkb;)";
int32 rb1[TK] = rkb)" + ldb0 + ";";
}
res += R"(
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + ldb1 + R"(;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
int32 d[TK] = *pd;
)" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + rka;
int32 da[TK] = *pda;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
@@ -153,18 +153,18 @@ public:
fp32 b)" + BS + R"( = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C);
pa = pa + d[newaxis, :];
pa = pa + da[newaxis, :];
pb = pb + )" + inc_pb + R"(;
b = *pb;
pd = pd + incd;)";
pda = pda + incd;)";
if(b_lut_){
res += R"(
pdb = pdb + TK;
pdb = pdb + incd;
db = *pdb;)";
}
res += R"(
pincd = pincd + incd;
d = *pd;
da = *pda;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;

View File

@@ -301,12 +301,12 @@ void conv::set_arg(driver::kernel *kernel,
}
std::vector<unsigned> conv::default_params() {
if(ty_==FPROP)
if(b_lut_)
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
else if(ty_ == FPROP)
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
else if(ty_ == BPROP)
return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2};
else if(ty_ == WGRAD)
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
}