From 54f888a27030b5f0afe02b5e5e6c81e1b0a76f6e Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Wed, 8 May 2019 10:09:30 -0400
Subject: [PATCH] [dnn/conv] some minor fixes

---
 examples/cpp/conv.cpp     |  74 ++-------
 include/triton/dnn/conv.h | 340 ++++++++++++++++++++++++++------------
 2 files changed, 247 insertions(+), 167 deletions(-)

diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 0af076612..4d55babe7 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -9,22 +9,19 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  // initialize just-in-time compiler
   triton::jit jit(context);
+  triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
   // initialization
   int32_t B = 4, NF = 32;
   int32_t D = 1, H = 24, W = 240;
-  int32_t NC = 64, T = 1, R = 3, S = 3;
+  int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 1, pad_w = 1;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
   int32_t RD = (D*upsample_d - T + 1 + 2*pad_d + stride_d - 1)/stride_d;
   int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
   int32_t RW = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
-  // equivalent matmul dimensions
-  int32_t M = B*RD*RH*RW;
-  int32_t N = NF;
-  int32_t K = NC*T*R*S;
+  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty);
   // convolution configuration
   std::vector<float> hc(B*RH*RW*NF);
   std::vector<float> rc(B*RH*RW*NF);
@@ -36,7 +33,8 @@ int main() {
   for(size_t i = 0; i < hb.size(); i++)
     hb[i] = (float)rand()/RAND_MAX;
   for(size_t i = 0; i < hc.size(); i++)
-    hc[i] = 0;
+    hc[i] = (float)rand()/RAND_MAX;
+  rc = hc;
   triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
   triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
   triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
@@ -45,80 +43,38 @@ int main() {
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  // memory strides for data
-  int32_t stride_i_w = 1;
-  int32_t stride_i_h = W*stride_i_w;
-  int32_t stride_i_d = H*stride_i_h;
-  int32_t stride_i_c = D*stride_i_d;
-  int32_t stride_i_n = NC*stride_i_c;
-  // memory stride for activations
-  int32_t stride_o_q = 1;
-  int32_t stride_o_p = RW*stride_o_q;
-  int32_t stride_o_m = RH*stride_o_p;
-  int32_t stride_o_k = RD*stride_o_m;
-  int32_t stride_o_n = NF*stride_o_k;
   // look-up table
-  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
   std::vector<int> h_delta, h_masks;
-  configuration.build_lut(h_delta, h_masks);
+  configuration.build_deltas(h_delta);
+  configuration.build_masks(h_masks);
   // benchmark a given convolution kernel
   auto benchmark = [&](triton::driver::kernel* kernel,
                        triton::jit::launch_information info) {
-    // launch info
     unsigned TM = info.global_range_size[0];
     unsigned TN = info.global_range_size[1];
-    // initialize constant memory
+    unsigned nthreads = info.num_threads;
+    std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
     triton::driver::buffer* delta = jit.get_buffer("delta");
     triton::driver::buffer* masks = jit.get_buffer("masks");
     stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
     stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
     stream->synchronize();
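+    // shapes, strides and padding are all provided by the configuration object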
-    // launch info
-    unsigned nthreads = info.num_threads;
-    std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
-    // set arguments
-    kernel->setArg(0, da);
-    kernel->setArg(1, db);
-    kernel->setArg(2, dc);
-    kernel->setArg(3, M);
-    kernel->setArg(4, N);
-    kernel->setArg(5, K);
-    kernel->setArg(6, B);
-    kernel->setArg(7, H);
-    kernel->setArg(8, W);
-    kernel->setArg(9, NF);
-    kernel->setArg(10, RH);
-    kernel->setArg(11, RW);
-    kernel->setArg(12, NC);
-    kernel->setArg(13, R);
-    kernel->setArg(14, S);
-    kernel->setArg(15, stride_i_n);
-    kernel->setArg(16, stride_i_c);
-    kernel->setArg(17, stride_i_h);
-    kernel->setArg(18, stride_i_w);
-    kernel->setArg(19, stride_o_n);
-    kernel->setArg(20, stride_o_k);
-    kernel->setArg(21, stride_o_p);
-    kernel->setArg(22, stride_o_q);
-    kernel->setArg(23, pad_h);
-    kernel->setArg(24, pad_w);
-    // dry run
+    configuration.set_arg(kernel, da, db, dc);
     stream->enqueue(kernel, grid, {nthreads, 1, 1});
     stream->synchronize();
-    // benchmark
     double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
                       [&](){ stream->synchronize(); }, *context->device());
-    return 2.*M*N*K / ts * 1e-3;
+    return configuration.get_nflops() / ts * 1e-3;
   };
-  std::string src = triton::dnn::conv::src();
+  std::string src = configuration.src();
 //  jit.autotune("conv", src.c_str(), benchmark);
-  jit.add_module("conv", src.c_str(), triton::dnn::conv::default_params());
+  jit.add_module("conv", src.c_str(), configuration.default_params());
   triton::driver::kernel* kernel = jit.get_function("conv");
   triton::jit::launch_information info = jit.get_launch_info("conv");
   std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
   stream->read(dc, true, 0, hc);
-  cpp_conv_nchw(NC, B, NF, D, H, W, T, R, S, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, RD, RH, RW, rc, ha, hb);
-  for(size_t i = 0; i < M*N; i++)
+  configuration.cpu_ref(rc.data(), ha.data(), hb.data());
+  for(size_t i = 0; i < hc.size(); i++)
   if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
     exit(EXIT_FAILURE);
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index e3fb91d43..11e222f3f 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -1,5 +1,7 @@
 #include <vector>
 #include <string>
+#include <array>
+#include <cmath>
+#include "triton/driver/stream.h"
+#include "triton/driver/kernel.h"
 
 namespace triton{
 namespace dnn{
@@ -15,10 +17,13 @@ public:
   conv(int B, int NC,
        int H, int W,
        int R, int S,
        int NF,
        int upsample_h, int upsample_w,
-       int pad_h, int pad_w)
+       int pad_h, int pad_w,
+       type ty = FPROP)
     : B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
       upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
-      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w)
+      stride_d_(1), stride_h_(1), stride_w_(1),
+      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w),
+      ty_(ty)
   {
     RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
     RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
@@ -26,9 +31,6 @@ public:
     M_ = B*RD_*RH_*RW_;
     N_ = NF;
     K_ = NC*T_*R_*S_;
-    Fs_ = T_*R_*S_;
-    TK_ = 8;
-    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
     // memory strides for data
     stride_a_w_ = 1;
     stride_a_h_ = W_*stride_a_w_;
@@ -41,88 +43,160 @@ public:
     stride_c_m_ = RH_*stride_c_p_;
     stride_c_k_ = RD_*stride_c_m_;
     stride_c_n_ = NF_*stride_c_k_;
+    // swap a and c for bprop
+    if(ty_ == BPROP){
+      std::swap(stride_a_n_, stride_c_n_);
+      std::swap(stride_a_c_, stride_c_k_);
+      std::swap(stride_a_h_, stride_c_p_);
+      std::swap(stride_a_w_, stride_c_q_);
+      std::swap(D_, RD_);
+      std::swap(H_, RH_);
+      std::swap(W_, RW_);
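+      // the swapped problem is the "full" convolution that computes the data
+      // gradient, so the padding is recomputed for the swapped shapes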
+      pad_d_ = (RD_ - D_ + T_ - 1) / 2;
+      pad_h_ = (RH_ - H_ + R_ - 1) / 2;
+      pad_w_ = (RW_ - W_ + S_ - 1) / 2;
+    }
+    // look-up table info
+    Fs_ = T_*R_*S_;
+    TK_ = 8;
+    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
   }
-
-  void build_lut(std::vector<int>& delta, std::vector<int>& masks) {
-    delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
-    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
-
-    /* unpack index wrt filters */
-    auto unpack = [&](int32_t trs){
-      int32_t tr = trs / S_;
-      int32_t s = trs - tr*S_;
-      int32_t t = tr / R_;
-      int32_t r = tr - t*R_;
-      return std::make_tuple(t, r, s);
-    };
-    /* increments */
-    for(size_t i = 0; i < Luts_; ++i)
-      delta[i] = (((i + TK_) % Luts_) - i);
-    /* deltas */
-    size_t Ds0 = Luts_;
-    size_t Ds1 = upsample_w_;
-    size_t Ds2 = upsample_h_;
-    size_t Ds3 = upsample_d_;
-    for(size_t pd = 0; pd < Ds3; ++pd)
-    for(size_t ph = 0; ph < Ds2; ++ph)
-    for(size_t pw = 0; pw < Ds1; ++pw){
-      int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
-      // cumulative increments
-      for(size_t i = 0; i < Ds0; ++i){
-        int32_t ctrs = i;
-        int32_t c = ctrs / Fs_;
-        int32_t t, r, s;
-        std::tie(t, r, s) = unpack(ctrs % Fs_);
-        // next indices
-        int32_t nextctrs = ctrs + TK_;
-        int32_t nextc = nextctrs / Fs_;
-        int32_t nextt, nextr, nexts;
-        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
-        // diffs
-        int32_t cdiff = nextc - c;
-        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
-        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
-        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
-        // delta pointers
-        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
-      }
-    }
+  void build_deltas(std::vector<int>& deltas){
+    deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
+    auto unpack = [&](int32_t trs){
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
+      return std::make_tuple(t, r, s);
+    };
+    for(size_t i = 0; i < Luts_; ++i)
+      deltas[i] = (((i + TK_) % Luts_) - i);
+    size_t Ds0 = Luts_;
+    size_t Ds1 = upsample_w_;
+    size_t Ds2 = upsample_h_;
+    size_t Ds3 = upsample_d_;
+    for(size_t pd = 0; pd < Ds3; ++pd)
+    for(size_t ph = 0; ph < Ds2; ++ph)
+    for(size_t pw = 0; pw < Ds1; ++pw){
+      int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
+      // cumulative increments
+      for(size_t i = 0; i < Ds0; ++i){
+        int32_t ctrs = i;
+        int32_t c = ctrs / Fs_;
+        int32_t t, r, s;
+        std::tie(t, r, s) = unpack(ctrs % Fs_);
+        // next indices
+        int32_t nextctrs = ctrs + TK_;
+        int32_t nextc = nextctrs / Fs_;
+        int32_t nextt, nextr, nexts;
+        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
+        // diffs
+        int32_t cdiff = nextc - c;
+        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
+        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
+        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
+        // delta pointers
+        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
+      }
+    }
+  }
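+
+  // each mask is a TK-bit field: bit j is set iff filter tap j of the current
+  // slice stays inside the image for the given (pd, ph, pw) boundary case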
-    /* Masks */
-    size_t Ms0 = Luts_;
-    size_t Ms1 = 2*pad_w_ + 1;
-    size_t Ms2 = 2*pad_h_ + 1;
-    size_t Ms3 = 2*pad_d_ + 1;
-    for(size_t pd = 0; pd < Ms3; ++pd)
-    for(size_t ph = 0; ph < Ms2; ++ph)
-    for(size_t pw = 0; pw < Ms1; ++pw){
-      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
-      for(size_t i = 0; i < Ms0; ++i){
-        int32_t t, r, s;
-        int32_t mask = 0x0;
-        for(size_t j = 0; j < TK_; ++j){
-          std::tie(t, r, s) = unpack((i + j) % Fs_);
-          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
-          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
-          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
-          mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
-        }
-        masks_ptr[i] = mask;
-      }
-    }
-    for(size_t i = 0; i < Luts_; ++i)
-      masks[i] = 0x0;
-
+  void build_masks(std::vector<int>& masks){
+    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
+    auto unpack = [&](int32_t trs){
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
+      return std::make_tuple(t, r, s);
+    };
+    size_t Ms0 = Luts_;
+    size_t Ms1 = 2*pad_w_ + 1;
+    size_t Ms2 = 2*pad_h_ + 1;
+    size_t Ms3 = 2*pad_d_ + 1;
+    for(size_t pd = 0; pd < Ms3; ++pd)
+    for(size_t ph = 0; ph < Ms2; ++ph)
+    for(size_t pw = 0; pw < Ms1; ++pw){
+      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
+      for(size_t i = 0; i < Ms0; ++i){
+        int32_t t, r, s;
+        int32_t mask = 0x0;
+        for(size_t j = 0; j < TK_; ++j){
+          std::tie(t, r, s) = unpack((i + j) % Fs_);
+          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
+          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
+          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
+          mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
+        }
+        masks_ptr[i] = mask;
+      }
+    }
+    for(size_t i = 0; i < Luts_; ++i)
+      masks[i] = 0x0;
   }
 
-  static std::vector<unsigned> default_params() {
-    return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
+  std::array<size_t, 3> get_grid(size_t TM, size_t TN){
+    return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
+  }
+
+  size_t get_nflops(){
+    return 2*size_t(M_)*N_*K_;
+  }
+
+  void set_arg(driver::kernel *kernel,
+               driver::buffer *a, driver::buffer *b, driver::buffer *c)
+  {
+    if(ty_ == BPROP)
+      std::swap(a, c);
+    kernel->setArg(0, a);
+    kernel->setArg(1, b);
+    kernel->setArg(2, c);
+    kernel->setArg(3, M_);
+    kernel->setArg(4, N_);
+    kernel->setArg(5, K_);
+    kernel->setArg(6, B_);
+    kernel->setArg(7, H_);
+    kernel->setArg(8, W_);
+    kernel->setArg(9, NF_);
+    kernel->setArg(10, RH_);
+    kernel->setArg(11, RW_);
+    kernel->setArg(12, NC_);
+    kernel->setArg(13, R_);
+    kernel->setArg(14, S_);
+    kernel->setArg(15, stride_a_n_);
+    kernel->setArg(16, stride_a_c_);
+    kernel->setArg(17, stride_a_h_);
+    kernel->setArg(18, stride_a_w_);
+    kernel->setArg(19, stride_c_n_);
+    kernel->setArg(20, stride_c_k_);
+    kernel->setArg(21, stride_c_p_);
+    kernel->setArg(22, stride_c_q_);
+    kernel->setArg(23, pad_h_);
+    kernel->setArg(24, pad_w_);
+  }
+
+  std::vector<unsigned> default_params() {
+    if(ty_ == FPROP)
+      return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
+    else
+      return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2};
   }
 
-  static std::string src(type ty = FPROP) {
-
+  std::string src() {
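+    // textual specialization: for BPROP the B operand is read transposed, so
+    // the tile shape, leading dimension and broadcast of pb are swapped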
+    std::string bs0 = "TN", bs1 = "TK";
+    std::string ldb0 = "*NF", ldb1 = "";
+    std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
+    std::string b = "b";
+    if(ty_ == BPROP){
+      std::swap(bs0, bs1);
+      std::swap(ldb0, ldb1);
+      std::swap(bcb0, bcb1);
+      b = "trans(b)";
+    }
     std::string res =
     R"(
     const tunable int32 TM = {16, 32, 64};
@@ -158,7 +232,7 @@ public:
     int32 rar[TK] = racr % R;
     int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
     fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
-    fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];
+    fp32* pb[)" + bs0 + ", " + bs1 + R"(] = b + rb1)" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
     __constant__ int32* pincd[TK] = delta + rka;
     __constant__ int32* pd[TK] = delta + R*S + rka;
     int32 d[TK] = *pd;
@@ -172,10 +246,10 @@ public:
     int32 checka1[TK] = 1 << rka;
     int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
     fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b[TN, TK] = *pb;
+    fp32 b[)" + bs0 + ", " + bs1 + R"(] = *pb;
     for(int32 k = K; k > 0; k = k - TK){
       C = dot(a, trans(b), C);
-      pb = pb + TK*NF;
+      pb = pb + TK)" + ldb0 + R"(;
       pa = pa + d[newaxis, :];
       b = *pb;
       pd = pd + incd;
@@ -203,42 +277,90 @@ public:
     return res;
   }
 
+  template<class IN_DTYPE, class OUT_DTYPE>
+  void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
+  {
+    auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
+                   int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
+    { return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; };
+
+    // a and c trade places for bprop, exactly as in set_arg
+    if(ty_==BPROP){
+      std::swap(A, C);
+    }
+    IN_DTYPE accs[1];
+    float tmp[1];
+    for(int32_t m = 0 ; m < RD_; ++m)
+    for(int32_t p = 0 ; p < RH_; ++p)
+    for(int32_t q = 0; q < RW_; ++q)
+    for(int32_t n = 0; n < B_; ++n)
+    for(int32_t k = 0; k < NF_ ; ++k)
+    {
+      for(int32_t i = 0; i < 1; ++i)
+        accs[i] = 0;
+      int32_t mm = m*stride_d_ - pad_d_;
+      int32_t pp = p*stride_h_ - pad_h_;
+      int32_t qq = q*stride_w_ - pad_w_;
+      for(int32_t kk = 0; kk < 1; ++kk)
+      for(int32_t c = 0; c < NC_; ++c)
+      for(int32_t t = 0; t < T_; ++t)
+      for(int32_t r = 0; r < R_; ++r)
+      for(int32_t s = 0; s < S_; ++s){
+        int32_t d = mm + t;
+        int32_t h = pp + r;
+        int32_t w = qq + s;
+        bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_);
+        IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0;
+        IN_DTYPE b;
+        if(ty_==FPROP)
+          b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+        else
+          b = B[idx(c, t, s, r, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+        accs[kk] = std::fma(a, b, accs[kk]);
+      }
+      for(int32_t kk = 0; kk < 1; ++kk){
+        tmp[kk] = accs[kk];
+      }
+      C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0];
+    }
+  }
+
 private:
   // image size
-  int B_;
-  int NC_;
-  int D_;
-  int H_;
-  int W_;
+  int32_t B_;
+  int32_t NC_;
+  int32_t D_;
+  int32_t H_;
+  int32_t W_;
   // filter size
-  int T_;
-  int R_;
-  int S_;
-  int NF_;
+  int32_t T_;
+  int32_t R_;
+  int32_t S_;
+  int32_t NF_;
   // activation size
-  int RD_;
-  int RH_;
-  int RW_;
+  int32_t RD_;
+  int32_t RH_;
+  int32_t RW_;
   // upsampling
-  int upsample_d_;
-  int upsample_h_;
-  int upsample_w_;
+  int32_t upsample_d_;
+  int32_t upsample_h_;
+  int32_t upsample_w_;
   // padding
-  int pad_d_;
-  int pad_h_;
-  int pad_w_;
+  int32_t pad_d_;
+  int32_t pad_h_;
+  int32_t pad_w_;
   // striding
-  int stride_d_;
-  int stride_h_;
-  int stride_w_;
+  int32_t stride_d_;
+  int32_t stride_h_;
+  int32_t stride_w_;
   // equivalent matmul
-  int M_;
-  int N_;
-  int K_;
+  int32_t M_;
+  int32_t N_;
+  int32_t K_;
   // helpers
-  int Fs_;
-  int TK_;
-  int Luts_;
+  int32_t Fs_;
+  int32_t TK_;
+  int32_t Luts_;
   // memory strides for data
   int32_t stride_a_w_;
   int32_t stride_a_h_;
@@ -251,7 +373,9 @@ private:
   int32_t stride_c_m_;
   int32_t stride_c_k_;
   int32_t stride_c_n_;
-
+  // type
+  type ty_;
+  bool is_bprop_;
 };
 
 }
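
As a sanity check on the look-up tables above, the cumulative increments written by
build_deltas() can be reproduced in isolation. The sketch below is not part of the
patch: it assumes the 1x3x3 filter, TK = 8 and unit upsampling from the example,
plus the NCHW strides of a single-channel 1x24x240 slab; the file name and all
values are illustrative only.

  // delta_lut_demo.cpp (hypothetical standalone program)
  #include <cstdio>
  #include <tuple>

  int main(){
    // same constants as the 1x3x3 / TK = 8 example, no upsampling
    int T = 1, R = 3, S = 3, TK = 8;
    int Fs = T*R*S;                        // 9 filter taps
    int Luts = (TK + Fs - 1) / Fs * Fs;    // rounded up to a multiple of Fs: 9
    // NCHW strides of a 1x24x240 image (stride_a_* in conv.h)
    int W = 240, H = 24, D = 1;
    int stride_w = 1, stride_h = W, stride_d = H*stride_h, stride_c = D*stride_d;
    // unpack a flattened (t, r, s) filter index, as in build_deltas()
    auto unpack = [&](int trs){
      int tr = trs / S, s = trs - tr*S;
      int t  = tr / R,  r = tr - t*R;
      return std::make_tuple(t, r, s);
    };
    for(int i = 0; i < Luts; ++i){
      int t, r, s, nt, nr, ns;
      std::tie(t, r, s)    = unpack(i % Fs);
      std::tie(nt, nr, ns) = unpack((i + TK) % Fs);
      int cdiff = (i + TK)/Fs - i/Fs;      // did we cross into the next channel?
      // pointer advance that moves a load TK steps along the (c, t, r, s) axis
      int delta = cdiff*stride_c + (nt - t)*stride_d + (nr - r)*stride_h + (ns - s)*stride_w;
      std::printf("delta[%d] = %d\n", i, delta);
    }
    return 0;
  }

For this shape Luts_ rounds up to 9, so delta[0] stays within the filter window
(2*240 + 2 = 482) while entries 1 through 8 also hop one input channel forward.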