From ece7beea3cf52e7b06b9469b6e8430bb316d44d1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 15 May 2019 14:57:31 -0400 Subject: [PATCH] [dnn/conv]: now using look-up table for wgrad computation as well --- examples/cpp/conv.cpp | 17 +- examples/cpp/dot.cpp | 2 +- include/triton/dnn/conv.h | 338 ++++++++++++++++++++++---------------- 3 files changed, 201 insertions(+), 156 deletions(-) diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp index 76c8cd096..f836edcb4 100644 --- a/examples/cpp/conv.cpp +++ b/examples/cpp/conv.cpp @@ -38,12 +38,6 @@ int main() { stream->write(db, true, 0, hb); stream->write(dc, true, 0, hc); stream->synchronize(); - // look-up table - std::vector h_delta, h_masks; - if(ty != triton::dnn::conv::WGRAD){ - configuration.build_deltas(h_delta); - configuration.build_masks(h_masks); - } // benchmark a given convolution kernel auto benchmark = [&](triton::driver::kernel* kernel, triton::jit::launch_information info) { @@ -51,12 +45,7 @@ int main() { unsigned TN = info.global_range_size[1]; unsigned nthreads = info.num_threads; std::array grid = configuration.get_grid(TM, TN); - if(ty != triton::dnn::conv::WGRAD){ - triton::driver::buffer* delta = jit.get_buffer("delta"); - triton::driver::buffer* masks = jit.get_buffer("masks"); - stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); - stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); - } + configuration.init(stream, jit); stream->synchronize(); configuration.set_arg(kernel, da, db, dc); stream->enqueue(kernel, grid, {nthreads, 1, 1}); @@ -66,7 +55,7 @@ int main() { return configuration.get_nflops() / ts * 1e-3; }; std::string src = configuration.src(); -// jit.autotune("conv", src.c_str(), benchmark); + jit.autotune("conv", src.c_str(), benchmark); jit.add_module("conv", src.c_str(), configuration.default_params()); triton::driver::kernel* kernel = jit.get_function("conv"); triton::jit::launch_information info = jit.get_launch_info("conv"); @@ -74,7 +63,7 @@ int main() { stream->read(dc, true, 0, hc); configuration.cpu_ref(rc.data(), ha.data(), hb.data()); for(size_t i = 0; i < hc.size(); i++){ - if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); } diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 0c735d9f4..3dde373ef 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -68,7 +68,7 @@ int main() { stream->read(dc, true, 0, hc); simple_gemm(AT, BT, rc, ha, hb, M, N, K); for(size_t i = 0; i < M*N; i++) - if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); } diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h index b2e5cd3dc..0afa77088 100644 --- a/include/triton/dnn/conv.h +++ b/include/triton/dnn/conv.h @@ -4,6 +4,7 @@ #include #include "triton/driver/stream.h" #include "triton/driver/kernel.h" +#include "triton/jit.h" namespace triton{ namespace dnn{ @@ -46,6 +47,9 @@ public: // swap b and c for wgrad if(ty_ == WGRAD){ shapes_b_.swap(shapes_c_); + std::swap(BD_, CD_); + std::swap(BH_, CH_); + std::swap(BW_, CW_); } // leading dimensions auto set_ld = [](const std::vector& shapes, @@ -62,6 +66,8 @@ public: set_ld(shapes_b_, ld_b_); set_ld(shapes_c_, ld_c_); // equivalent matmul + b_trans_ = ty_ != BPROP; + b_lut_ = ty_ == WGRAD; if(ty_ == WGRAD){ M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3]; N_ = shapes_c_[4]; @@ -73,11 +79,20 @@ public: K_ = shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3]; } // look-up table info - Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3]; - if(ty_ == BPROP) - Fs_ *= shapes_b_[4]; + if(ty_ == FPROP) + Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3]; + else + Fs_ = K_; TK_ = 8; Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_; + build_deltas(); + build_masks(); + size_t cst_size = h_b_deltas_.size()*4; + is_b_deltas_cst_ = cst_size < 65536; + cst_size += h_a_deltas_.size()*4; + is_a_deltas_cst = cst_size < 65536; + cst_size += h_masks_.size()*4; + is_mask_cst_ = cst_size < 65536; } size_t a_size() { @@ -99,14 +114,14 @@ public: return shapes_c_; } - void build_deltas(std::vector& deltas){ - if(ty_ == WGRAD) - throw std::runtime_error("no look-up table necessary for wgrad"); - deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); + void build_deltas(){ + h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); + if(b_lut_) + h_b_deltas_.resize(Luts_); auto unpack = [&](int32_t ltrs){ - int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_; - int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_; + int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_); + int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_); int32_t tr = trs / BW_; int32_t s = trs % BW_; int32_t t = tr / BH_; @@ -119,7 +134,7 @@ public: }; for(size_t i = 0; i < Luts_; ++i) - deltas[i] = (((i + TK_) % Luts_) - i); + h_a_deltas_[i] = (((i + TK_) % Luts_) - i); size_t Ds0 = Luts_; size_t Ds1 = upsample_w_; @@ -128,7 +143,7 @@ public: for(size_t pd = 0; pd < Ds3; ++pd) for(size_t ph = 0; ph < Ds2; ++ph) for(size_t pw = 0; pw < Ds1; ++pw){ - int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2]; + int32_t* deltas_ptr = &h_a_deltas_[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2]; // cumulative increments for(size_t i = 0; i < Ds0; ++i) { // unpack @@ -145,18 +160,31 @@ public: int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_; int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_; // delta pointers - deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4]; + if(ty_ == WGRAD) + deltas_ptr[i] = cdiff*ld_a_[0] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4]; + else + deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4]; + } + } + + if(ty_ == WGRAD){ + for(size_t i = 0; i < Ds0; ++i) { + int32_t c, t, r, s; + int32_t nextc, nextt, nextr, nexts; + std::tie(c, t, r, s) = unpack(i); + std::tie(nextc, nextt, nextr, nexts) = unpack(i + TK_); + int32_t cdiff = nextc - c, tdiff = nextt - t, rdiff = nextr - r, sdiff = nexts - s; + h_b_deltas_[i] = cdiff*ld_b_[0] + tdiff*ld_b_[2] + rdiff*ld_b_[3] + sdiff*ld_b_[4]; } } } - void build_masks(std::vector& masks){ - if(ty_ == WGRAD) - throw std::runtime_error("no look-up table necessary for wgrad"); - masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_); + void build_masks(){ + h_masks_.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_); + auto unpack = [&](int32_t ltrs){ - int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_; - int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_; + int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_); + int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_); int32_t tr = trs / BW_; int32_t s = trs % BW_; int32_t t = tr / BH_; @@ -174,7 +202,7 @@ public: for(size_t pd = 0; pd < Ms3; ++pd) for(size_t ph = 0; ph < Ms2; ++ph) for(size_t pw = 0; pw < Ms1; ++pw){ - int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2]; + int32_t* masks_ptr = &h_masks_[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2]; for(size_t i = 0; i < Ms0; ++i){ int32_t l, t, r, s; int32_t mask = 0x0; @@ -189,7 +217,7 @@ public: } } for(size_t i = 0; i < Luts_; ++i) - masks[i] = 0x0; + h_masks_[i] = 0x0; } std::array get_grid(size_t TM, size_t TN){ @@ -200,6 +228,27 @@ public: return 2.*M_*N_*K_; } + void init(driver::stream *stream, triton::jit &jit) { + auto init_lut = [&](bool is_cst, const char *name, std::vector host) -> triton::driver::buffer*{ + if(host.empty()) + return nullptr; + size_t nbytes = host.size()*4; + // get buffer + triton::driver::buffer* buffer; + if(is_cst) + buffer = jit.get_buffer(name); + else + buffer = triton::driver::buffer::create(stream->context(), nbytes); + // copy + stream->write(buffer, false, 0, nbytes, host.data()); + return buffer; + }; + + d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_); + d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_); + d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_); + } + void set_arg(driver::kernel *kernel, driver::buffer *a, driver::buffer *b, driver::buffer *c) { @@ -211,70 +260,107 @@ public: kernel->setArg(5, K_); kernel->setArg(6, AH_); kernel->setArg(7, AW_); + kernel->setArg(8, BH_); + kernel->setArg(9, BW_); + kernel->setArg(10, CH_); + kernel->setArg(11, CW_); + // A arguments if(ty_ == WGRAD){ - kernel->setArg(8, CH_); - kernel->setArg(9, CW_); - kernel->setArg(10, BH_); - kernel->setArg(11, BW_); + kernel->setArg(12, ld_a_[1]); + kernel->setArg(13, ld_a_[0]); } else{ - kernel->setArg(8, BH_); - kernel->setArg(9, BW_); - kernel->setArg(10, CH_); - kernel->setArg(11, CW_); + kernel->setArg(12, ld_a_[0]); + kernel->setArg(13, ld_a_[1]); } - kernel->setArg(12, ld_a_[0]); - kernel->setArg(13, ld_a_[1]); kernel->setArg(14, ld_a_[2]); kernel->setArg(15, ld_a_[3]); kernel->setArg(16, ld_a_[4]); - kernel->setArg(17, ld_b_[0]); - kernel->setArg(18, ld_b_[1]); - kernel->setArg(19, ld_b_[2]); - kernel->setArg(20, ld_b_[3]); - kernel->setArg(21, ld_b_[4]); - kernel->setArg(22, ld_c_[0]); - kernel->setArg(23, ld_c_[1]); - kernel->setArg(24, ld_c_[2]); - kernel->setArg(25, ld_c_[3]); - kernel->setArg(26, ld_c_[4]); + // B arguments + if(ty_ == WGRAD){ + kernel->setArg(17, ld_b_[0]); + kernel->setArg(18, ld_b_[2]); + kernel->setArg(19, ld_b_[3]); + kernel->setArg(20, ld_b_[4]); + kernel->setArg(21, ld_b_[1]); + } + else{ + kernel->setArg(17, ld_b_[0]); + kernel->setArg(18, ld_b_[1]); + kernel->setArg(19, ld_b_[2]); + kernel->setArg(20, ld_b_[3]); + kernel->setArg(21, ld_b_[4]); + } + // C arguments + if(ty_ == WGRAD){ + kernel->setArg(22, ld_c_[0]); + kernel->setArg(23, ld_c_[4]); + kernel->setArg(24, ld_c_[1]); + kernel->setArg(25, ld_c_[2]); + kernel->setArg(26, ld_c_[3]); + } + else{ + kernel->setArg(22, ld_c_[0]); + kernel->setArg(23, ld_c_[1]); + kernel->setArg(24, ld_c_[2]); + kernel->setArg(25, ld_c_[3]); + kernel->setArg(26, ld_c_[4]); + } kernel->setArg(27, pad_h_); kernel->setArg(28, pad_w_); + size_t idx = 29; + if(!is_a_deltas_cst) + kernel->setArg(idx++, d_a_deltas_); + if(!is_b_deltas_cst_) + kernel->setArg(idx++, d_b_deltas_); + if(!is_mask_cst_) + kernel->setArg(idx++, d_masks_); } std::vector default_params() { - if(ty_ == FPROP) + if(ty_==FPROP) return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; else if(ty_ == BPROP) return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2}; - else - return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8}; + else if(ty_ == WGRAD) + return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8}; } - std::string xprop() { - bool trans_b = ty_ == FPROP; - std::string BS = trans_b ?"[TN,TK]" : "[TK, TN]"; - std::string bcb0 = trans_b ?"[:, newaxis]" : "[newaxis, :]"; - std::string bcb1 = trans_b ?"[newaxis, :]" : "[:, newaxis]"; - std::string ldb0 = trans_b ?"*ldb_s" : ""; - std::string ldb1 = trans_b ?"" : "*ldb_c"; - std::string useb = trans_b ?"trans(b)" : "b"; - std::string flipr = trans_b?"" : "BH - 1 -"; - std::string flips = trans_b?"" : "BW - 1 -"; - std::string ax = trans_b?"crs" : "rsc"; - std::vector redax = {"BH", "BW", "N"}; - if(trans_b) + std::string src() { + bool is_wgrad = ty_ == WGRAD; + std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]"; + std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]"; + std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]"; + std::string ldb0 = b_trans_ ? "*ldb_s" : ""; + std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c"; + std::string useb = b_trans_ ? "trans(b)" : "b"; + std::string flipr = b_trans_ ? "" : "BH - 1 -"; + std::string flips = b_trans_ ? "" : "BW - 1 -"; + std::string ax = b_trans_ ? "crs" : "rsc"; + std::vector redax; + if(b_trans_) redax = {"C", "BH", "BW"}; + else + redax = {"BH", "BW", "N"}; + std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0; + std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : ""; + std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : ""; + std::string masks_mem = is_mask_cst_? "__constant__" : ""; std::string res = R"( const tunable int32 TM = {16, 32, 64}; const tunable int32 TN = {16, 32, 64}; const tunable int32 TK = {8}; - - __constant__ int32* delta = alloc_const int32[1024]; - __constant__ int32* masks = alloc_const int32[4096]; + )"; + if(is_a_deltas_cst) + res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n"; + if(is_wgrad && is_b_deltas_cst_) + res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n"; + if(is_mask_cst_) + res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n"; + res += R"( void conv(read_only restrict fp32 *a, read_only restrict fp32 *b, @@ -286,13 +372,20 @@ public: int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k, int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q, - int32 pad_h, int32 pad_w){ + int32 pad_h, int32 pad_w)"; + if(!is_a_deltas_cst) + res += ", int32* delta\n"; + if(is_wgrad && !is_b_deltas_cst_) + res += ", int32* b_delta\n"; + if(!is_mask_cst_) + res += ", int32* masks\n"; + res += R"(){ int32 rxa[TM] = get_global_range[TM](0); int32 rb0[TN] = get_global_range[TN](1); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 C[TM, TN] = 0; - int32 Fs = )" + std::to_string(Fs_) + R"(; + int32 ldlut = )" + std::to_string(Fs_) + R"(; int32 rabh[TM] = rxa / CW; int32 raw[TM] = rxa % CW - pad_w; int32 rab[TM] = rabh / CH; @@ -305,16 +398,31 @@ public: rar = )" + flipr + R"( rar; ras = )" + flips + R"( ras; int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; - fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; - fp32* pb)" + BS + " = b + rkb" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(; - __constant__ int32* pincd[TK] = delta + rka; - __constant__ int32* pd[TK] = delta + Fs + rka; + fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)"; + if(ty_ == WGRAD){ + res += R"( + int32 rbcr[TK] = rkb / BW; + int32 rbs[TK] = rkb % BW; + int32 rbc[TK] = rbcr / BH; + int32 rbr[TK] = rbcr % BH; + int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s; + )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb; + int32 db[TK] = *pdb;)"; + } + else{ + res += R"( + int32 rb1[TK] = rkb;)"; + } + res += R"( + fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(; + )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka; + )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka; int32 d[TK] = *pd; int32 incd[TK] = *pincd; int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); - __constant__ int32* pm[TM] = masks + Fs + maskw*Fs + maskh*Fs*(2*pad_w + 1); - __constant__ int32* pincm[TM] = delta; + )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1); + )" + a_delta_mem + R"( int32* pincm[TM] = delta; int32 incm[TM] = *pincm; int32 checka0[TM] = *pm; int32 checka1[TK] = 1 << rka; @@ -324,9 +432,15 @@ public: for(int32 k = K; k > 0; k = k - TK){ C = dot(a, )" + useb + R"(, C); pa = pa + d[newaxis, :]; - pb = pb + TK)" + ldb0 + R"(; + pb = pb + )" + inc_pb + R"(; b = *pb; - pd = pd + incd; + pd = pd + incd;)"; + if(ty_ == WGRAD){ + res += R"( + pdb = pdb + incd; + db = *pdb;)"; + } + res += R"( pincd = pincd + incd; d = *pd; incd = *pincd; @@ -342,86 +456,17 @@ public: int32 rc1[TN] = get_global_range[TN](1); int32 rcn[TM] = rxc / (CH*CW); int32 rcpq[TM] = rxc % (CH*CW); - int32 rc0[TM] = rcn * ldc_n + rcpq; + int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q; fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; int1 checkc0[TM] = rxc < M; int1 checkc1[TN] = rc1 < N; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; @checkc *pc = C; })"; + return res; } - // C = A * B - // where A is N,C,AH,AW - // B is N,K,BH,BW - // C is C,CH,CW,K - std::string wgrad() { - std::string res = - R"( - const tunable int32 TM = {16, 32, 64}; - const tunable int32 TN = {16, 32, 64}; - const tunable int32 TK = {8}; - - void conv(read_only restrict fp32 *a, - read_only restrict fp32 *b, - fp32 *c, - int32 M, int32 N, int32 K, - int32 AH, int32 AW, - int32 BH, int32 BW, - int32 CH, int32 CW, - int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, - int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q, - int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k, - int32 pad_h, int32 pad_w){ - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); - int32 rk[TK] = 0 ... TK; - fp32 C[TM, TN] = 0; - int32 racr[TM] = rxa / CW; - int32 raw_base[TM] = rxa % CW - pad_w; - int32 rac[TM] = racr / CH; - int32 rah_base[TM] = racr % CH - pad_h; - fp32* pa_base[TM, TK] = a + rac[:, newaxis]*lda_c; - fp32* pb_base[TN, TK] = b + ryb[:, newaxis]*ldb_k; - for(int32 k = K; k > 0; k = k - TK){ - int32 rknp[TK] = rk / BW; - int32 rkq[TK] = rk % BW; - int32 rkn[TK] = rknp / BH; - int32 rkp[TK] = rknp % BH; - int32 rah[TM, TK] = rah_base[:, newaxis] + rkp[newaxis, :]; - int32 raw[TM, TK] = raw_base[:, newaxis] + rkq[newaxis, :]; - int1 checka[TM, TK] = (rah >= 0) && (rah < AH) && (raw >= 0) && (raw < AW); - fp32* pa[TM, TK] = pa_base + rah*lda_h + raw*lda_w + rkn*lda_n; - fp32* pb[TN, TK] = pb_base + rkp*ldb_p + rkq*ldb_q + rkn*ldb_n; - fp32 A[TM, TK] = checka ? *pa : 0; - fp32 B[TN, TK] = *pb; - C = dot(A, trans(B), C); - rk = rk + TK; - } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1); - int32 rccr[TM] = rxc / CW; - int32 rcs[TM] = rxa % CW; - int32 rcc[TM] = racr / CH; - int32 rcr[TM] = racr % CH; - int32 rc0[TM] = rcc*ldc_c + rcr*ldc_r + rcs*ldc_s; - fp32* pc[TM, TN] = c + rc0[:, newaxis] + ryc[newaxis, :]*ldc_k; - int1 checkc0[TM] = rxc < M; - int1 checkc1[TN] = ryc < N; - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - @checkc *pc = C; - })"; - return res; - } - - std::string src() { - if(ty_ == FPROP || ty_ == BPROP) - return xprop(); - else - return wgrad(); - } - template void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) { @@ -552,9 +597,20 @@ private: // memory stride for C std::vector shapes_c_; std::vector ld_c_; + // constant memory + std::vector h_a_deltas_; + std::vector h_b_deltas_; + std::vector h_masks_; + driver::buffer* d_a_deltas_; + driver::buffer* d_b_deltas_; + driver::buffer* d_masks_; + bool is_a_deltas_cst; + bool is_b_deltas_cst_; + bool is_mask_cst_; // type type ty_; - bool is_bprop_; + bool b_trans_; + bool b_lut_; }; }