[dnn/conv] Added the option to use a look-up table for filters in all
operations
This commit is contained in:
@@ -14,7 +14,7 @@ int main() {
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 56, W = 56;
|
||||
int32_t NC = 32, T = 1, R = 3, S = 3;
|
||||
int32_t NC = 16, T = 1, R = 3, S = 3;
|
||||
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
|
||||
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
|
@@ -60,7 +60,7 @@ public:
|
||||
redax = {"C", "BH", "BW"};
|
||||
else
|
||||
redax = {"BH", "BW", "N"};
|
||||
std::string inc_pb = b_lut_ ? "db[newaxis, :]" : "TK" + ldb0;
|
||||
std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0;
|
||||
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
|
||||
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
|
||||
std::string masks_mem = is_mask_cst_? "__constant__" : "";
|
||||
@@ -133,13 +133,13 @@ public:
|
||||
}
|
||||
else{
|
||||
res += R"(
|
||||
int32 rb1[TK] = rkb;)";
|
||||
int32 rb1[TK] = rkb)" + ldb0 + ";";
|
||||
}
|
||||
res += R"(
|
||||
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
|
||||
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
|
||||
int32 d[TK] = *pd;
|
||||
)" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + rka;
|
||||
int32 da[TK] = *pda;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
@@ -153,18 +153,18 @@ public:
|
||||
fp32 b)" + BS + R"( = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + d[newaxis, :];
|
||||
pa = pa + da[newaxis, :];
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;)";
|
||||
pda = pda + incd;)";
|
||||
if(b_lut_){
|
||||
res += R"(
|
||||
pdb = pdb + TK;
|
||||
pdb = pdb + incd;
|
||||
db = *pdb;)";
|
||||
}
|
||||
res += R"(
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
da = *pda;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
|
@@ -301,12 +301,12 @@ void conv::set_arg(driver::kernel *kernel,
|
||||
}
|
||||
|
||||
// Returns a hard-coded std::vector<unsigned> selected by the b_lut_ flag and by
// ty_ (compared against FPROP, BPROP and WGRAD). The b_lut_ and WGRAD vectors
// hold 13 values; the FPROP and BPROP vectors hold 14.
// NOTE(review): this text is a diff rendering with the +/- markers stripped, so
// `if(ty_==FPROP)` and `if(b_lut_)` below are presumably the removed and added
// versions of the same line rather than two nested conditions -- confirm
// against the repository before relying on the branch structure shown here.
// NOTE(review): there is no final `else`/`return`; if none of the conditions
// holds, control reaches the end of a non-void function without returning a
// value. Consider an explicit fallback or error for that case.
std::vector<unsigned> conv::default_params() {
|
||||
if(ty_==FPROP)
|
||||
if(b_lut_)
|
||||
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
|
||||
else if(ty_ == FPROP)
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
|
||||
else if(ty_ == BPROP)
|
||||
return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2};
|
||||
else if(ty_ == WGRAD)
|
||||
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user