[dnn/conv] Added the option to have a look-up table for filters for all operations
This commit is contained in:
Philippe Tillet
2019-05-22 19:03:33 -04:00
parent f8291af7ef
commit 3f3eb1c2a4
3 changed files with 13 additions and 13 deletions

View File

@@ -14,7 +14,7 @@ int main() {
// initialization // initialization
int32_t B = 4, NF = 32; int32_t B = 4, NF = 32;
int32_t D = 1, H = 56, W = 56; int32_t D = 1, H = 56, W = 56;
int32_t NC = 32, T = 1, R = 3, S = 3; int32_t NC = 16, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0; int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1; int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;

View File

@@ -60,7 +60,7 @@ public:
redax = {"C", "BH", "BW"}; redax = {"C", "BH", "BW"};
else else
redax = {"BH", "BW", "N"}; redax = {"BH", "BW", "N"};
std::string inc_pb = b_lut_ ? "db[newaxis, :]" : "TK" + ldb0; std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0;
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : ""; std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : ""; std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
std::string masks_mem = is_mask_cst_? "__constant__" : ""; std::string masks_mem = is_mask_cst_? "__constant__" : "";
@@ -133,13 +133,13 @@ public:
} }
else{ else{
res += R"( res += R"(
int32 rb1[TK] = rkb;)"; int32 rb1[TK] = rkb)" + ldb0 + ";";
} }
res += R"( res += R"(
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(; fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + ldb1 + R"(;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka; )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka; )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + rka;
int32 d[TK] = *pd; int32 da[TK] = *pda;
int32 incd[TK] = *pincd; int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
@@ -153,18 +153,18 @@ public:
fp32 b)" + BS + R"( = *pb; fp32 b)" + BS + R"( = *pb;
for(int32 k = K; k > 0; k = k - TK){ for(int32 k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C); C = dot(a, )" + useb + R"(, C);
pa = pa + d[newaxis, :]; pa = pa + da[newaxis, :];
pb = pb + )" + inc_pb + R"(; pb = pb + )" + inc_pb + R"(;
b = *pb; b = *pb;
pd = pd + incd;)"; pda = pda + incd;)";
if(b_lut_){ if(b_lut_){
res += R"( res += R"(
pdb = pdb + TK; pdb = pdb + incd;
db = *pdb;)"; db = *pdb;)";
} }
res += R"( res += R"(
pincd = pincd + incd; pincd = pincd + incd;
d = *pd; da = *pda;
incd = *pincd; incd = *pincd;
pm = pm + incm; pm = pm + incm;
pincm = pincm + incm; pincm = pincm + incm;

View File

@@ -301,12 +301,12 @@ void conv::set_arg(driver::kernel *kernel,
} }
std::vector<unsigned> conv::default_params() { std::vector<unsigned> conv::default_params() {
if(ty_==FPROP) if(b_lut_)
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
else if(ty_ == FPROP)
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
else if(ty_ == BPROP) else if(ty_ == BPROP)
return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2}; return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2};
else if(ty_ == WGRAD)
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
} }