diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 5d1f095b0..0af076612 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -25,6 +25,7 @@ int main() {
   int32_t M = B*RD*RH*RW;
   int32_t N = NF;
   int32_t K = NC*T*R*S;
+  // convolution configuration
   std::vector<float> hc(B*RH*RW*NF);
   std::vector<float> rc(B*RH*RW*NF);
   std::vector<float> ha(B*NC*H*W);
@@ -57,8 +58,9 @@ int main() {
   int32_t stride_o_k = RD*stride_o_m;
   int32_t stride_o_n = NF*stride_o_k;
   // look-up table
+  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
   std::vector<int32_t> h_delta, h_masks;
-  triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
+  configuration.build_lut(h_delta, h_masks);
   // benchmark a given convolution kernel
   auto benchmark = [&](triton::driver::kernel* kernel,
                        triton::jit::launch_information info) {
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index b76726575..e3fb91d43 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -12,81 +12,106 @@ public:
     WGRAD
   };
 
-  static void build_lut(int TK,
-                        int stride_d, int stride_h, int stride_w, int stride_c,
-                        int pad_d, int pad_h, int pad_w,
-                        int T, int R, int S,
-                        std::vector<int32_t>& res, std::vector<int32_t>& masks) {
-    /* convolution parameters */
-    int F = T * R * S;
-    int Nlut = (TK + F - 1) / F * F;
-    int upsample_w = 1;
-    int upsample_h = 1;
-    int upsample_d = 1;
+
+  conv(int B, int NC, int H, int W, int R, int S, int NF,
+       int upsample_h, int upsample_w,
+       int pad_h, int pad_w)
+    : B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
+      upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
+      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w), stride_d_(1), stride_h_(1), stride_w_(1)
+  {
+    RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
+    RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
+    RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
+    M_ = B*RD_*RH_*RW_;
+    N_ = NF;
+    K_ = NC*T_*R_*S_;
+    Fs_ = T_*R_*S_;
+    TK_ = 8;
+    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
+    // memory strides for data
+    stride_a_w_ = 1;
+    stride_a_h_ = W_*stride_a_w_;
+    stride_a_d_ = H_*stride_a_h_;
+    stride_a_c_ = D_*stride_a_d_;
+    stride_a_n_ = NC_*stride_a_c_;
+    // memory strides for activations
+    stride_c_q_ = 1;
+    stride_c_p_ = RW_*stride_c_q_;
+    stride_c_m_ = RH_*stride_c_p_;
+    stride_c_k_ = RD_*stride_c_m_;
+    stride_c_n_ = NF_*stride_c_k_;
+  }
+
+
+  void build_lut(std::vector<int32_t>& delta, std::vector<int32_t>& masks) {
+    delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
+    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
+
     /* unpack index wrt filters */
     auto unpack = [&](int32_t trs){
-      int32_t tr = trs / S;
-      int32_t s = trs - tr*S;
-      int32_t t = tr / R;
-      int32_t r = tr - t*R;
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
       return std::make_tuple(t, r, s);
    };
     /* increments */
-    for(size_t i = 0; i < Nlut; ++i)
-      res[i] = (((i + TK) % Nlut) - i);
+    for(size_t i = 0; i < Luts_; ++i)
+      delta[i] = (((i + TK_) % Luts_) - i);
     /* deltas */
-    size_t Ds0 = Nlut;
-    size_t Ds1 = upsample_w;
-    size_t Ds2 = upsample_h;
-    size_t Ds3 = upsample_d;
+    size_t Ds0 = Luts_;
+    size_t Ds1 = upsample_w_;
+    size_t Ds2 = upsample_h_;
+    size_t Ds3 = upsample_d_;
     for(size_t pd = 0; pd < Ds3; ++pd)
     for(size_t ph = 0; ph < Ds2; ++ph)
     for(size_t pw = 0; pw < Ds1; ++pw){
-      int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
+      int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
       // cumulative increments
       for(size_t i = 0; i < Ds0; ++i){
         int32_t ctrs = i;
-        int32_t c = ctrs / F;
+        int32_t c = ctrs / Fs_;
         int32_t t, r, s;
-        std::tie(t, r, s) = unpack(ctrs % F);
+        std::tie(t, r, s) = unpack(ctrs % Fs_);
         // next indices
-        int32_t nextctrs = ctrs + TK;
-        int32_t nextc = nextctrs / F;
+        int32_t nextctrs = ctrs + TK_;
+        int32_t nextc = nextctrs / Fs_;
         int32_t nextt, nextr, nexts;
-        std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
+        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
         // diffs
         int32_t cdiff = nextc - c;
-        int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
-        int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
-        int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
+        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
+        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
+        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
         // delta pointers
-        deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
+        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
       }
     }
     /* Masks */
-    size_t Ms0 = Nlut;
-    size_t Ms1 = 2*pad_w + 1;
-    size_t Ms2 = 2*pad_h + 1;
-    size_t Ms3 = 2*pad_d + 1;
+    size_t Ms0 = Luts_;
+    size_t Ms1 = 2*pad_w_ + 1;
+    size_t Ms2 = 2*pad_h_ + 1;
+    size_t Ms3 = 2*pad_d_ + 1;
     for(size_t pd = 0; pd < Ms3; ++pd)
     for(size_t ph = 0; ph < Ms2; ++ph)
     for(size_t pw = 0; pw < Ms1; ++pw){
-      int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
+      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
       for(size_t i = 0; i < Ms0; ++i){
         int32_t t, r, s;
         int32_t mask = 0x0;
-        for(size_t j = 0; j < TK; ++j){
-          std::tie(t, r, s) = unpack((i + j) % F);
-          bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
-          bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
-          bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
+        for(size_t j = 0; j < TK_; ++j){
+          std::tie(t, r, s) = unpack((i + j) % Fs_);
+          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
+          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
+          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
           mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
         }
         masks_ptr[i] = mask;
       }
     }
-    for(size_t i = 0; i < Nlut; ++i)
+    for(size_t i = 0; i < Luts_; ++i)
       masks[i] = 0x0;
   }
 
@@ -95,20 +120,6 @@ public:
     return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
   }
 
-  static void init_cst(int stride_d, int stride_h, int stride_w, int stride_c,
-                       int pad_d, int pad_h, int pad_w,
-                       int T, int R, int S,
-                       std::vector<int32_t> &h_delta, std::vector<int32_t> &h_masks) {
-    int upsample_d = 1;
-    int upsample_h = 1;
-    int upsample_w = 1;
-    int TK = 8;
-    int F = T * R * S;
-    int nlut = (TK + F - 1) / F * F;
-    h_delta.resize(nlut + upsample_d*upsample_h*upsample_w*nlut);
-    h_masks.resize(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
-    build_lut(TK, stride_d, stride_h, stride_w, stride_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
-  }
 
   static std::string src(type ty = FPROP) {
@@ -191,6 +202,56 @@ public:
   })";
     return res;
   }
+
+private:
+  // image size
+  int B_;
+  int NC_;
+  int D_;
+  int H_;
+  int W_;
+  // filter size
+  int T_;
+  int R_;
+  int S_;
+  int NF_;
+  // activation size
+  int RD_;
+  int RH_;
+  int RW_;
+  // upsampling
+  int upsample_d_;
+  int upsample_h_;
+  int upsample_w_;
+  // padding
+  int pad_d_;
+  int pad_h_;
+  int pad_w_;
+  // striding
+  int stride_d_;
+  int stride_h_;
+  int stride_w_;
+  // equivalent matmul
+  int M_;
+  int N_;
+  int K_;
+  // helpers
+  int Fs_;
+  int TK_;
+  int Luts_;
+  // memory strides for data
+  int32_t stride_a_w_;
+  int32_t stride_a_h_;
+  int32_t stride_a_d_;
+  int32_t stride_a_c_;
+  int32_t stride_a_n_;
+  // memory strides for activations
+  int32_t stride_c_q_;
+  int32_t stride_c_p_;
+  int32_t stride_c_m_;
+  int32_t stride_c_k_;
+  int32_t stride_c_n_;
+
 };
 }
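Reviewer note on the constructor arithmetic: for a typical 3x3 filter (T_ = 1, R_ = S_ = 3, values assumed here for illustration) the filter volume is Fs_ = 9, and with TK_ = 8 the rounding gives Luts_ = (8 + 9 - 1) / 9 * 9 = 9, i.e. TK_ rounded up to the next multiple of Fs_. The standalone sketch below replays the same shape and LUT-size arithmetic; the image sizes are hypothetical, and stride/upsampling are fixed to 1 as in this patch.

#include <cstdio>

// Replays the shape/LUT-size arithmetic from the conv constructor.
// All sizes are hypothetical; stride and upsampling are 1, as in the patch.
int main() {
  int H = 24, W = 24, T = 1, R = 3, S = 3;
  int pad_h = 0, pad_w = 0, stride_h = 1, stride_w = 1;
  int TK = 8;                          // reduction tile size (TK_)
  int Fs = T * R * S;                  // filter volume (Fs_): 9
  int Luts = (TK + Fs - 1) / Fs * Fs;  // TK rounded up to a multiple of Fs: 9
  // output shape, matching RH_/RW_ with upsampling = 1
  int RH = (H - R + 1 + 2*pad_h + stride_h - 1) / stride_h;  // 22
  int RW = (W - S + 1 + 2*pad_w + stride_w - 1) / stride_w;  // 22
  std::printf("Fs=%d Luts=%d RH=%d RW=%d\n", Fs, Luts, RH, RW);
  return 0;
}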
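On how the two halves of `delta` fit together: the first Luts_ entries are cyclic increments for the LUT index itself (delta[i] = ((i + TK_) % Luts_) - i), and each subsequent plane of Luts_ entries holds pointer increments that move an image pointer TK_ steps along the packed (c, t, r, s) reduction axis. Because Luts_ is a multiple of Fs_, the increment for reduction index p depends only on p mod Luts_, which is what makes the cyclic table valid for the whole reduction. Below is a scalar sketch of the intended consumption pattern, assuming unit stride, no upsampling and no padding (so the mask table is ignored); the function and parameter names are illustrative, not taken from the kernel.

#include <cstdint>
#include <vector>

// Scalar sketch: walk the reduction axis in steps of TK using the LUT
// produced by conv::build_lut, instead of re-deriving (c, t, r, s).
// `a` is the image tensor and `base` the offset of the first element read;
// both names are hypothetical. Masks are ignored (no padding assumed).
float reduce_crs(const float* a, int64_t base,
                 const std::vector<int32_t>& delta,
                 int K, int TK, int Luts) {
  const int32_t* incs = delta.data() + Luts;  // first plane of pointer deltas
  int lut = 0;                                // index into the cyclic table
  int64_t off = base;
  float acc = 0.f;
  for (int ctrs = 0; ctrs < K; ctrs += TK) {
    acc += a[off];       // placeholder for the TK-wide tile of work
    off += incs[lut];    // advance the image pointer by TK reduction steps
    lut += delta[lut];   // equivalent to lut = (lut + TK) % Luts
  }
  return acc;
}

The masks table is built analogously: each masks_ptr[i] word packs one in-bounds bit per reduction step j < TK_ (the `mask |= ... << j` loop in build_lut), so a kernel can zero out out-of-bounds loads near the padding border without branching.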