diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp index 2bbec482a..179fca421 100644 --- a/examples/cpp/conv.cpp +++ b/examples/cpp/conv.cpp @@ -14,7 +14,7 @@ int main() { // initialization int32_t B = 4, NF = 32; int32_t D = 1, H = 56, W = 56; - int32_t NC = 32, T = 1, R = 3, S = 3; + int32_t NC = 16, T = 1, R = 3, S = 3; int32_t pad_d = 0, pad_h = 0, pad_w = 0; int32_t stride_d = 1, stride_h = 1, stride_w = 1; int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h index 05a3bd1ee..1c9127466 100644 --- a/include/triton/dnn/conv.h +++ b/include/triton/dnn/conv.h @@ -60,7 +60,7 @@ public: redax = {"C", "BH", "BW"}; else redax = {"BH", "BW", "N"}; - std::string inc_pb = b_lut_ ? "db[newaxis, :]" : "TK" + ldb0; + std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0; std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : ""; std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : ""; std::string masks_mem = is_mask_cst_? "__constant__" : ""; @@ -133,13 +133,13 @@ public: } else{ res += R"( - int32 rb1[TK] = rkb;)"; + int32 rb1[TK] = rkb)" + ldb0 + ";"; } res += R"( - fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(; + fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + ldb1 + R"(; )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka; - )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka; - int32 d[TK] = *pd; + )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + rka; + int32 da[TK] = *pda; int32 incd[TK] = *pincd; int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); @@ -153,18 +153,18 @@ public: fp32 b)" + BS + R"( = *pb; for(int32 k = K; k > 0; k = k - TK){ C = dot(a, )" + useb + R"(, C); - pa = pa + d[newaxis, :]; + pa = pa + da[newaxis, :]; pb = pb + )" + inc_pb + R"(; b = *pb; - pd = pd + incd;)"; + pda = pda + incd;)"; if(b_lut_){ res += R"( - pdb = pdb + TK; + pdb = pdb + incd; db = *pdb;)"; } res += R"( pincd = pincd + incd; - d = *pd; + da = *pda; incd = *pincd; pm = pm + incm; pincm = pincm + incm; diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp index 05b9211a5..889a37f00 100644 --- a/lib/dnn/conv.cpp +++ b/lib/dnn/conv.cpp @@ -301,12 +301,12 @@ void conv::set_arg(driver::kernel *kernel, } std::vector conv::default_params() { - if(ty_==FPROP) + if(b_lut_) + return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8}; + else if(ty_ == FPROP) return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; else if(ty_ == BPROP) return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2}; - else if(ty_ == WGRAD) - return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8}; }