diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp index df5b22803..8cddeb588 100644 --- a/examples/cpp/conv.cpp +++ b/examples/cpp/conv.cpp @@ -10,13 +10,13 @@ int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::jit jit(context); - triton::dnn::conv::type ty = triton::dnn::conv::BPROP; + triton::dnn::conv::type ty = triton::dnn::conv::WGRAD; // initialization int32_t B = 4, NF = 32; int32_t D = 1, H = 24, W = 240; int32_t NC = 32, T = 1, R = 3, S = 3; int32_t pad_d = 0, pad_h = 1, pad_w = 1; - triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty); + triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty); // convolution configuration std::vector hc(configuration.c_size()); std::vector rc(configuration.c_size()); @@ -40,8 +40,10 @@ int main() { stream->synchronize(); // look-up table std::vector h_delta, h_masks; - configuration.build_deltas(h_delta); - configuration.build_masks(h_masks); + if(ty != triton::dnn::conv::WGRAD){ + configuration.build_deltas(h_delta); + configuration.build_masks(h_masks); + } // benchmark a given convolution kernel auto benchmark = [&](triton::driver::kernel* kernel, triton::jit::launch_information info) { @@ -49,10 +51,12 @@ int main() { unsigned TN = info.global_range_size[1]; unsigned nthreads = info.num_threads; std::array grid = configuration.get_grid(TM, TN); - triton::driver::buffer* delta = jit.get_buffer("delta"); - triton::driver::buffer* masks = jit.get_buffer("masks"); - stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); - stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); + if(ty != triton::dnn::conv::WGRAD){ + triton::driver::buffer* delta = jit.get_buffer("delta"); + triton::driver::buffer* masks = jit.get_buffer("masks"); + stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); + stream->write(masks, false, 0, h_masks.size()*4, 
h_masks.data()); + } stream->synchronize(); configuration.set_arg(kernel, da, db, dc); stream->enqueue(kernel, grid, {nthreads, 1, 1}); @@ -69,11 +73,11 @@ int main() { std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; stream->read(dc, true, 0, hc); configuration.cpu_ref(rc.data(), ha.data(), hb.data()); -// std::cout << c[0] << std::endl; - for(size_t i = 0; i < hc.size(); i++) + for(size_t i = 0; i < hc.size(); i++){ if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); - } + } + } std::cout << "Pass!" << std::endl; } diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h index 85bb1e038..c29bb925b 100644 --- a/include/triton/dnn/conv.h +++ b/include/triton/dnn/conv.h @@ -1,5 +1,7 @@ #include #include +#include +#include #include "triton/driver/stream.h" #include "triton/driver/kernel.h" @@ -15,74 +17,91 @@ public: }; - conv(int B, int NC, int H, int W, int R, int S, int NF, - int upsample_h, int upsample_w, - int pad_h, int pad_w, + conv(int B, int NC, + int D, int H, int W, + int T, int R, int S, int NF, + int upsample_d, int upsample_h, int upsample_w, + int pad_d, int pad_h, int pad_w, type ty = FPROP) - : B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF), - upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w), + : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF), + upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w), stride_d_(1), stride_h_(1), stride_w_(1), - pad_d_(0), pad_h_(pad_h), pad_w_(pad_w), + pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w), ty_(ty) { - RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_; - RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_; - RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_; - // memory strides for data - stride_a_w_ = 1; - stride_a_h_ = W_*stride_a_w_; - 
stride_a_d_ = H_*stride_a_h_; - stride_a_c_ = D_*stride_a_d_; - stride_a_n_ = NC_*stride_a_c_; - // memory stride for activations - stride_c_q_ = 1; - stride_c_p_ = RW_*stride_c_q_; - stride_c_m_ = RH_*stride_c_p_; - stride_c_k_ = RD_*stride_c_m_; - stride_c_n_ = NF_*stride_c_k_; + CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_; + CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_; + CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_; + // shapes + shapes_a_ = {NB_, NC_, AD_, AH_, AW_}; + shapes_b_ = {NC_, BD_, BH_, BW_, NF_}; + shapes_c_ = {NB_, NF_, CD_, CH_, CW_}; // swap a and c for bprop if(ty_ == BPROP){ - std::swap(stride_a_n_, stride_c_n_); - std::swap(stride_a_c_, stride_c_k_); - std::swap(stride_a_h_, stride_c_p_); - std::swap(stride_a_w_, stride_c_q_); - std::swap(D_, RD_); - std::swap(H_, RH_); - std::swap(W_, RW_); - std::swap(NF_, NC_); - pad_d_ = (RD_ - D_ + T_ - 1) / 2; - pad_h_ = (RH_ - H_ + R_ - 1) / 2; - pad_w_ = (RW_ - W_ + S_ - 1) / 2; + pad_d_ = (CD_ - AD_ + BD_ - 1) / 2; + pad_h_ = (CH_ - AH_ + BH_ - 1) / 2; + pad_w_ = (CW_ - AW_ + BW_ - 1) / 2; + shapes_a_.swap(shapes_c_); } + // swap b and c for wgrad + if(ty_ == WGRAD){ + shapes_b_.swap(shapes_c_); + } + // leading dimensions + auto set_ld = [](const std::vector& shapes, + std::vector& ld) { + size_t size = shapes.size(); + ld.resize(size); + ld[4] = 1; + ld[3] = shapes[4]*ld[4]; + ld[2] = shapes[3]*ld[3]; + ld[1] = shapes[2]*ld[2]; + ld[0] = shapes[1]*ld[1]; + }; + set_ld(shapes_a_, ld_a_); + set_ld(shapes_b_, ld_b_); + set_ld(shapes_c_, ld_c_); // equivalent matmul - M_ = B_*RD_*RH_*RW_; - N_ = NF_; - K_ = NC_*T_*R_*S_; + if(ty_ == WGRAD){ + M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3]; + N_ = shapes_c_[4]; + K_ = shapes_b_[0]*shapes_b_[2]*shapes_b_[3]*shapes_b_[4]; + } + else{ + M_ = shapes_c_[0]*shapes_c_[2]*shapes_c_[3]*shapes_c_[4]; + N_ = shapes_c_[1]; + K_ = 
shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3]; + } // look-up table info - Fs_ = T_*R_*S_; + Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3]; TK_ = 8; Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_; } size_t a_size() { - return B_*NC_*D_*H_*W_; + return std::accumulate(shapes_a_.begin(), shapes_a_.end(), + 1, std::multiplies()); } size_t b_size() { - return NC_*NF_*T_*R_*S_; + return std::accumulate(shapes_b_.begin(), shapes_b_.end(), + 1, std::multiplies()); } size_t c_size() { - return B_*NF_*RD_*RH_*RW_; + return std::accumulate(shapes_c_.begin(), shapes_c_.end(), + 1, std::multiplies()); } void build_deltas(std::vector& deltas){ + if(ty_ == WGRAD) + throw std::runtime_error("no look-up table necessary for wgrad"); deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); auto unpack = [&](int32_t trs){ - int32_t tr = trs / S_; - int32_t s = trs - tr*S_; - int32_t t = tr / R_; - int32_t r = tr - t*R_; + int32_t tr = trs / BW_; + int32_t s = trs - tr*BW_; + int32_t t = tr / BH_; + int32_t r = tr - t*BH_; return std::make_tuple(t, r, s); }; for(size_t i = 0; i < Luts_; ++i) @@ -112,18 +131,20 @@ public: int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_; int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_; // delta pointers - deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_; + deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4]; } } } void build_masks(std::vector& masks){ + if(ty_ == WGRAD) + throw std::runtime_error("no look-up table necessary for wgrad"); masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_); auto unpack = [&](int32_t trs){ - int32_t tr = trs / S_; - int32_t s = trs - tr*S_; - int32_t t = tr / R_; - int32_t r = tr - t*R_; + int32_t tr = trs / BW_; + int32_t s = trs - tr*BW_; + int32_t t = tr / BH_; + int32_t r = tr - t*BH_; return std::make_tuple(t, r, s); }; size_t Ms0 = Luts_; @@ -139,9 +160,9 @@ public: int32_t mask = 0x0; 
for(size_t j = 0; j < TK_; ++j){ std::tie(t, r, s) = unpack((i + j) % Fs_); - bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_); - bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_); - bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_); + bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (BD_ + pad_d_); + bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (BH_ + pad_h_); + bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (BW_ + pad_w_); mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j; } masks_ptr[i] = mask; @@ -168,46 +189,40 @@ public: kernel->setArg(3, M_); kernel->setArg(4, N_); kernel->setArg(5, K_); - kernel->setArg(6, B_); - kernel->setArg(7, H_); - kernel->setArg(8, W_); - kernel->setArg(9, NF_); - kernel->setArg(10, RH_); - kernel->setArg(11, RW_); - kernel->setArg(12, NC_); - kernel->setArg(13, R_); - kernel->setArg(14, S_); - kernel->setArg(15, stride_a_n_); - kernel->setArg(16, stride_a_c_); - kernel->setArg(17, stride_a_h_); - kernel->setArg(18, stride_a_w_); - kernel->setArg(19, stride_c_n_); - kernel->setArg(20, stride_c_k_); - kernel->setArg(21, stride_c_p_); - kernel->setArg(22, stride_c_q_); - kernel->setArg(23, pad_h_); - kernel->setArg(24, pad_w_); + kernel->setArg(6, AH_); + kernel->setArg(7, AW_); + kernel->setArg(8, BH_); + kernel->setArg(9, BW_); + kernel->setArg(10, CH_); + kernel->setArg(11, CW_); + kernel->setArg(12, ld_a_[0]); + kernel->setArg(13, ld_a_[1]); + kernel->setArg(14, ld_a_[2]); + kernel->setArg(15, ld_a_[3]); + kernel->setArg(16, ld_a_[4]); + kernel->setArg(17, ld_b_[0]); + kernel->setArg(18, ld_b_[1]); + kernel->setArg(19, ld_b_[2]); + kernel->setArg(20, ld_b_[3]); + kernel->setArg(21, ld_b_[4]); + kernel->setArg(22, ld_c_[0]); + kernel->setArg(23, ld_c_[1]); + kernel->setArg(24, ld_c_[2]); + kernel->setArg(25, ld_c_[3]); + kernel->setArg(26, ld_c_[4]); + kernel->setArg(27, pad_h_); + kernel->setArg(28, pad_w_); } std::vector default_params() { -// if(ty_ == FPROP) 
+ if(ty_ == FPROP || ty_ == BPROP) return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; -// else -// return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2}; + else + return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8}; } - std::string src() { - std::string bs0 = "TN", bs1 = "TK"; - std::string ldb0 = "*NF", ldb1 = ""; - std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; - std::string b = "b"; - if(ty_ == BPROP){ - std::swap(bs0, bs1); - std::swap(ldb0, ldb1); - std::swap(bcb0, bcb1); - b = "trans(b)"; - } + std::string xprop() { std::string res = R"( const tunable int32 TM = {16, 32, 64}; @@ -221,36 +236,37 @@ public: read_only restrict fp32 *b, fp32 *c, int32 M, int32 N, int32 K, - int32 B, int32 H, int32 W, - int32 NF, int32 RH, int32 RW, - int32 NC, int32 R, int32 S, - int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w, - int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q, + int32 AH, int32 AW, + int32 BH, int32 BW, + int32 CH, int32 CW, + int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, + int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k, + int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q, int32 pad_h, int32 pad_w){ int32 rxa[TM] = get_global_range[TM](0); int32 rb0[TN] = get_global_range[TN](1); int32 rka[TK] = 0 ... TK; int32 rb1[TK] = 0 ... 
TK; fp32 C[TM, TN] = 0; - int32 rabh[TM] = rxa / RW; - int32 raw[TM] = rxa % RW - pad_w; - int32 rab[TM] = rabh / RH; - int32 rah[TM] = rabh % RH - pad_h; + int32 rabh[TM] = rxa / CW; + int32 raw[TM] = rxa % CW - pad_w; + int32 rab[TM] = rabh / CH; + int32 rah[TM] = rabh % CH - pad_h; int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; - int32 racr[TK] = rka / S; - int32 ras[TK] = rka % S; - int32 rac[TK] = racr / R; - int32 rar[TK] = racr % R; + int32 racr[TK] = rka / BW; + int32 ras[TK] = rka % BW; + int32 rac[TK] = racr / BH; + int32 rar[TK] = racr % BH; int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; - fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis]; + fp32* pb[TN, TK] = b + rb1[newaxis, :]*ldb_s + rb0[:, newaxis]; __constant__ int32* pincd[TK] = delta + rka; - __constant__ int32* pd[TK] = delta + R*S + rka; + __constant__ int32* pd[TK] = delta + BH*BW + rka; int32 d[TK] = *pd; int32 incd[TK] = *pincd; - int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + R - H, 0); - int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + S - W, 0); - __constant__ int32* pm[TM] = masks + R*S + maskw*R*S + maskh*R*S*(2*pad_w + 1); + int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); + int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); + __constant__ int32* pm[TM] = masks + BH*BW + maskw*BH*BW + maskh*BH*BW*(2*pad_w + 1); __constant__ int32* pincm[TM] = delta; int32 incm[TM] = *pincm; int32 checka0[TM] = *pm; @@ -260,7 +276,7 @@ public: fp32 b[TN, TK] = *pb; for(int32 k = K; k > 0; k = k - TK){ C = dot(a, trans(b), C); - pb = pb + TK*NF; + pb = pb + TK*ldb_s; pa = pa + d[newaxis, :]; b = *pb; pd = pd + incd; @@ -276,8 +292,8 @@ public: } int32 rxc[TM] = get_global_range[TM](0); int32 rc1[TN] = get_global_range[TN](1); - int32 rcn[TM] = rxc / (RH*RW); - int32 rcpq[TM] = rxc % (RH*RW); + int32 rcn[TM] = rxc / (CH*CW); + int32 rcpq[TM] = rxc % (CH*CW); int32 rc0[TM] = rcn * ldc_n + rcpq; fp32* 
pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; int1 checkc0[TM] = rxc < M; @@ -288,62 +304,169 @@ public: return res; } + // C = A * B + // where A is N,C,AH,AW + // B is N,K,BH,BW + // C is C,CH,CW,K + std::string wgrad() { + std::string res = + R"( + const tunable int32 TM = {16, 32, 64}; + const tunable int32 TN = {16, 32, 64}; + const tunable int32 TK = {8}; + + void conv(read_only restrict fp32 *a, + read_only restrict fp32 *b, + fp32 *c, + int32 M, int32 N, int32 K, + int32 AH, int32 AW, + int32 CH, int32 CW, + int32 BH, int32 BW, + int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, + int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q, + int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k, + int32 pad_h, int32 pad_w){ + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rk[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + int32 racr[TM] = rxa / CW; + int32 raw_base[TM] = rxa % CW - pad_w; + int32 rac[TM] = racr / CH; + int32 rah_base[TM] = racr % CH - pad_h; + fp32* pa_base[TM, TK] = a + rac[:, newaxis]*lda_c; + fp32* pb_base[TN, TK] = b + ryb[:, newaxis]*ldb_k; + for(int32 k = K; k > 0; k = k - TK){ + int32 rknp[TK] = rk / BW; + int32 rkq[TK] = rk % BW; + int32 rkn[TK] = rknp / BH; + int32 rkp[TK] = rknp % BH; + int32 rah[TM, TK] = rah_base[:, newaxis] + rkp[newaxis, :]; + int32 raw[TM, TK] = raw_base[:, newaxis] + rkq[newaxis, :]; + int1 checka[TM, TK] = (rah >= 0) && (rah < AH) && (raw >= 0) && (raw < AW); + fp32* pa[TM, TK] = pa_base + rah*lda_h + raw*lda_w + rkn*lda_n; + fp32* pb[TN, TK] = pb_base + rkp*ldb_p + rkq*ldb_q + rkn*ldb_n; + fp32 A[TM, TK] = checka ? 
*pa : 0; + fp32 B[TN, TK] = *pb; + C = dot(A, trans(B), C); + rk = rk + TK; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + int32 rccr[TM] = rxc / CW; + int32 rcs[TM] = rxc % CW; + int32 rcc[TM] = rccr / CH; + int32 rcr[TM] = rccr % CH; + int32 rc0[TM] = rcc*ldc_c + rcr*ldc_r + rcs*ldc_s; + fp32* pc[TM, TN] = c + rc0[:, newaxis] + ryc[newaxis, :]*ldc_k; + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + @checkc *pc = C; + })"; + return res; + } + + std::string src() { + if(ty_ == FPROP || ty_ == BPROP) + return xprop(); + else + return wgrad(); + } + + template<class IN_DTYPE, class OUT_DTYPE> + void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) + { + IN_DTYPE acc; + for(int32_t n = 0; n < shapes_c_[0]; ++n) + for(int32_t k = 0; k < shapes_c_[1] ; ++k) + for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd) + for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch) + for(int32_t cw = 0; cw < shapes_c_[4]; ++cw) + { + acc = 0; + int32_t d = cd*stride_d_ - pad_d_; + int32_t h = ch*stride_h_ - pad_h_; + int32_t w = cw*stride_w_ - pad_w_; + for(int32_t c = 0; c < shapes_b_[0]; ++c) + for(int32_t bd = 0; bd < shapes_b_[1]; ++bd) + for(int32_t bh = 0; bh < shapes_b_[2]; ++bh) + for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){ + int32_t ad = d + bd; + int32_t ah = h + bh; + int32_t aw = w + bw; + bool in_bounds = (ad >= 0 && ad < shapes_a_[2] && + ah >= 0 && ah < shapes_a_[3] && + aw >= 0 && aw < shapes_a_[4]); + IN_DTYPE a = 0; + if(in_bounds) + a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]]; + IN_DTYPE b = B[c*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + k*ld_b_[4]]; + acc = std::fma(a, b, acc); + } + C[n*ld_c_[0] + k*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc; + } + } + + template<class IN_DTYPE, class OUT_DTYPE> + void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) + { + IN_DTYPE acc; + for(int32_t c = 0 ; c < shapes_c_[0]; ++c) + for(int32_t cd = 0; cd < shapes_c_[1]; ++cd) + 
for(int32_t ch = 0; ch < shapes_c_[2]; ++ch) + for(int32_t cw = 0; cw < shapes_c_[3]; ++cw) + for(int32_t k = 0 ; k < shapes_c_[4]; ++k) + { + acc = 0; + int32_t d = cd*stride_d_ - pad_d_; + int32_t h = ch*stride_h_ - pad_h_; + int32_t w = cw*stride_w_ - pad_w_; + for(int32_t n = 0; n < shapes_b_[0]; ++n) + for(int32_t bd = 0; bd < shapes_b_[2]; ++bd) + for(int32_t bh = 0; bh < shapes_b_[3]; ++bh) + for(int32_t bw = 0; bw < shapes_b_[4]; ++bw){ + int32_t ad = d + bd; + int32_t ah = h + bh; + int32_t aw = w + bw; + bool in_bounds = (ad >= 0 && ad < shapes_a_[2] && + ah >= 0 && ah < shapes_a_[3] && + aw >= 0 && aw < shapes_a_[4]); + IN_DTYPE a = 0; + if(in_bounds) + a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]]; + IN_DTYPE b = B[n*ld_b_[0] + k*ld_b_[1] + bd*ld_b_[2] + bh*ld_b_[3] + bw*ld_b_[4]]; + acc = std::fma(a, b, acc); + } + C[c*ld_c_[0] + cd*ld_c_[1] + ch*ld_c_[2] + cw*ld_c_[3] + k*ld_c_[4]] = acc; + } + } + template void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) { - auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u, - int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) - { return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }; - - IN_DTYPE accs[1]; - float tmp[1]; - for(int32_t m = 0 ; m < RD_; ++m) - for(int32_t p = 0 ; p < RH_; ++p) - for(int32_t q = 0; q < RW_; ++q) - for(int32_t n = 0; n < B_; ++n) - for(int32_t k = 0; k < NF_ ; ++k) - { - for(int32_t i = 0; i < 1; ++i) - accs[i] = 0; - int32_t mm = m*stride_d_ - pad_d_; - int32_t pp = p*stride_h_ - pad_h_; - int32_t qq = q*stride_w_ - pad_w_; - for(int32_t kk = 0; kk < 1; ++kk) - for(int32_t c = 0; c < NC_; ++c) - for(int32_t t = 0; t < T_; ++t) - for(int32_t r = 0; r < R_; ++r) - for(int32_t s = 0; s < S_; ++s){ - int32_t d = mm + t; - int32_t h = pp + r; - int32_t w = qq + s; - bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_); - IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0; - IN_DTYPE b 
= B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)]; - accs[kk] = std::fma(a, b, accs[kk]); - } - for(int32_t kk = 0; kk < 1; ++kk){ - tmp[kk] = accs[kk]; - } - C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0]; - } + if(ty_ == FPROP || ty_ == BPROP) + cpu_xprop(C, A, B); + else + cpu_wgrad(C, A, B); } private: // image size - int32_t B_; + int32_t NB_; int32_t NC_; - int32_t D_; - int32_t H_; - int32_t W_; + int32_t AD_; + int32_t AH_; + int32_t AW_; // filter size - int32_t T_; - int32_t R_; - int32_t S_; + int32_t BD_; + int32_t BH_; + int32_t BW_; int32_t NF_; // activation size - int32_t RD_; - int32_t RH_; - int32_t RW_; + int32_t CD_; + int32_t CH_; + int32_t CW_; // upsampling int32_t upsample_d_; int32_t upsample_h_; @@ -364,18 +487,15 @@ private: int32_t Fs_; int32_t TK_; int32_t Luts_; - // memory strides for data - int32_t stride_a_w_; - int32_t stride_a_h_; - int32_t stride_a_d_; - int32_t stride_a_c_; - int32_t stride_a_n_; - // memory stride for activations - int32_t stride_c_q_; - int32_t stride_c_p_; - int32_t stride_c_m_; - int32_t stride_c_k_; - int32_t stride_c_n_; + // memory strides for A + std::vector shapes_a_; + std::vector ld_a_; + // memory strides for B + std::vector shapes_b_; + std::vector ld_b_; + // memory stride for C + std::vector shapes_c_; + std::vector ld_c_; // type type ty_; bool is_bprop_; diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 7927e5400..6b638abaa 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -812,7 +812,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & std::swap(b_idx[0], b_idx[1]); Value *a = TA->get_value(a_idx); Value *b = TB->get_value(b_idx); +// res = builder.CreateCall(f_mul_add, {ConstantFP::get(a->getType(), 1), ConstantFP::get(b->getType(), 1), res}); res = builder.CreateCall(f_mul_add, {a, b, res}); + } result->set_value(idx, res); });