diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index b244e8ec2..ed949c74c 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -1,9 +1,11 @@ #include #include +#include #include "triton/runtime/jit.h" #include "triton/driver/backend.h" #include "triton/driver/stream.h" #include "triton/tools/bench.hpp" +#include "triton/dnn/shift.h" // input layout: C, H, W, BS // filter layout: C, K @@ -36,96 +38,6 @@ void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS, } } -// K = channels -// M = batch * height * width -// N = number of feature maps - -const char* src = -R"( -const tunable int32 TM = {16, 32, 64, 128}; -const tunable int32 TN = {16, 32, 64, 128}; -const tunable int32 TK = {8}; - -__constant__ int32* delta = alloc_const int32[256]; -__constant__ int32* masks = alloc_const int32[8192]; - -void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, - int32 M, int32 N, int32 K, - int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){ - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); - int32 rka[TK] = 0 ... TK; - int32 rkb[TK] = 0 ... TK; - fp32 C[TM, TN] = 0; - fp32* pxa[TM, TK] = a + rxa[:, newaxis]; - fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; - __constant__ int32* pd[TK] = delta + rka; - int32 pad_h = AR/2; - int32 pad_w = AS/2; - int32 rawhc[TM] = rxa / ABS; - int32 raw[TM] = rawhc % AW - pad_w; - int32 rahc[TM] = rawhc / AW; - int32 rah[TM] = rahc % AH - pad_h; - int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0); - int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0); - __constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1); - __constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :]; - for(int32 k = K; k > 0; k = k - TK){ - int32 delta[TK] = *pd; - fp32 *pa[TM, TK] = pxa + delta[newaxis, :]; - int1 m[TM, TK] = *pm > 0; - fp32 a[TM, TK] = m ? *pa : 0; - fp32 b[TN, TK] = *pb; - C = dot(a, trans(b), C); - pb = pb + TK*N; - pd = pd + TK; - pm = pm + TK; - } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1); - fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis]; - int1 checkc0[TM] = rxc < M; - int1 checkc1[TN] = ryc < N; - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - @checkc *pc = C; -} -)"; - -std::vector shift_deltas(// strides - int32_t stride_w, int32_t stride_h, int32_t stride_c, - // shift - int32_t C, - const std::vector& shift_h, - const std::vector& shift_w) { - std::vector res(C); - for(unsigned c = 0; c < C; c++){ - res[c] = c*stride_c; - res[c] += shift_h[c]*stride_h; - res[c] += shift_w[c]*stride_w; - } - return res; -} - -std::vector shift_masks(int32_t C, - const std::vector& shift_h, - const std::vector& shift_w, - int32_t R, int32_t S) { - size_t S0 = C; - size_t S1 = R; - size_t S2 = S; - std::vector res(S0*S1*S2); - for(size_t ph = 0; ph < S1; ++ph) - for(size_t pw = 0; pw < S2; ++pw){ - int32_t* ptr = &res[ph*S0 + pw*S0*S1]; - for(size_t i = 0; i < S0; ++i){ - bool in_bounds_h = shift_h[i] + ph >= 0 && shift_h[i] + ph < R; - bool in_bounds_w = shift_w[i] + pw >= 0 && shift_w[i] + pw < S; - ptr[i] = in_bounds_h && in_bounds_w; - } - } - return res; -} - int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); @@ -136,20 +48,6 @@ int main() { int32_t BS = 4, F = 128; int32_t H = 32, W = 32; int32_t C = 128; - // equivalent matmul dimensions - int32_t M = BS*H*W; - int32_t N = F; - int32_t K = C; - std::cout << M << " " << N << " " << K << std::endl; - std::vector hc(BS*H*W*F); - std::vector rc(BS*H*W*F); - std::vector ha(BS*C*H*W); - std::vector hb(F*C); - // strides - int32_t stride_i_bs = 1; - int32_t stride_i_w = BS*stride_i_bs; - int32_t stride_i_h = W*stride_i_w; - int32_t stride_i_c = H*stride_i_h; // random shifts std::vector shift_h(C); std::vector shift_w(C); @@ -157,83 +55,63 @@ int main() { shift_h[c] = rand() % R - R/2; shift_w[c] = rand() % S - S/2; } - // initialize buffers - srand(0); - for(int c = 0 ; c < C; c++) - for(int h = 0 ; h < H; h++) - for(int w = 0 ; w < W; w++) - for(int bs = 0 ; bs < BS; bs++){ - float value = (float)rand() / RAND_MAX; - size_t idx = bs + w*stride_i_w + h*stride_i_h + c*stride_i_c; - ha[idx] = value; - } - for(size_t i = 0; i < hb.size(); i++) - hb[i] = (float)rand() / RAND_MAX; - for(size_t i = 0; i < hc.size(); i++) - hc[i] = 0; + // configuration + triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w); + // host buffers + std::vector hc(shift.c_size()); + std::vector rc(shift.c_size()); + std::vector ha(shift.a_size()); + std::vector hb(shift.b_size()); + // device buffers triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4); triton::driver::stream* stream = triton::driver::stream::create(context); + // initialize host + srand(0); + for(size_t i = 0; i < ha.size(); i++) + ha[i] = (float)rand() / RAND_MAX; + for(size_t i = 0; i < hb.size(); i++) + hb[i] = (float)rand() / RAND_MAX; + for(size_t i = 0; i < hc.size(); i++) + hc[i] = 0; + // initialize device stream->write(da, true, 0, ha); stream->write(db, true, 0, hb); stream->write(dc, true, 0, hc); stream->synchronize(); - std::vector h_delta = shift_deltas(stride_i_w, stride_i_h, stride_i_c, C, shift_h, shift_w); - std::vector h_masks = shift_masks(C, shift_h, shift_w, R, S); - // benchmark a given matrix multiplication kernel + // benchmark auto benchmark = [&](triton::driver::kernel* kernel, triton::jit::launch_information info) { + shift.init(stream, (triton::driver::cu_module*)kernel->module()); // launch info unsigned TM = info.global_range_size[0]; unsigned TN = info.global_range_size[1]; unsigned nthreads = info.num_threads; - // initialize constant memory - triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta"); - triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks"); - stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); - stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); - stream->synchronize(); // set argument - kernel->setArg(0, da); - kernel->setArg(1, db); - kernel->setArg(2, dc); - kernel->setArg(3, M); - kernel->setArg(4, N); - kernel->setArg(5, K); - kernel->setArg(6, BS); - kernel->setArg(7, H); - kernel->setArg(8, W); - kernel->setArg(9, R); - kernel->setArg(10, S); - // dry run - std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}; - stream->enqueue(kernel, grid, {nthreads, 1, 1}); + shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads); stream->synchronize(); // benchmark - double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, + double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);}, [&](){ stream->synchronize(); }, context->device()); - ts = ts * 1e-9; - double tflops = 2.*M*N*K / ts * 1e-12; - return tflops; + return shift.get_nflops() / ts * 1e-3; }; // shift std::vector params = { - 16, 2, 64, - 32, 2, 64, - 16, 8, 2, 2, - 8, 8, - 4 + 8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4 }; - jit.autotune("shift", src, benchmark); - jit.add_module("shift", src, params); + std::ostringstream oss; + shift.src(oss); + std::string src = oss.str(); +// jit.autotune("shift", src.c_str(), benchmark); + jit.add_module("shift", src.c_str(), params); triton::driver::kernel* kernel = jit.get_function("shift"); triton::jit::launch_information info = jit.get_launch_info("shift"); std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; stream->read(dc, true, 0, hc); - shift_conv(C, H, W, BS, F, rc, ha, hb, shift_h, shift_w); - for(size_t i = 0; i < M*N; i++) + shift.cpu_ref(rc.data(), ha.data(), hb.data()); + for(size_t i = 0; i < hc.size(); i++) if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h new file mode 100644 index 000000000..6d6bda9de --- /dev/null +++ b/include/triton/dnn/shift.h @@ -0,0 +1,151 @@ +/* Copyright 2015-2017 Philippe Tillet +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +#ifndef TDL_INCLUDE_DNN_SHIFT_H +#define TDL_INCLUDE_DNN_SHIFT_H + +#include +#include +#include +#include +#include +#include "triton/driver/stream.h" +#include "triton/driver/kernel.h" + +namespace triton{ +namespace dnn{ + +class shift { + +public: + enum type { + FPROP + }; + +private: + void set_ld(const std::vector& shapes, + std::vector& ld); + +public: + + shift(int B, int NC, + int D, int H, int W, + int T, int R, int S, int NF, + const std::vector &shift_h, const std::vector &shift_w, + std::string a_ty = "fp32", std::string b_ty = "fp32", + type ty = FPROP, bool bias = false); + + // look-up table + void build_deltas(); + void build_masks(); + + // accessors + size_t a_size(); + size_t b_size(); + size_t c_size(); + std::vector c_shapes(); + + // device function + void init(driver::stream *stream, driver::cu_module *module); + void enqueue(driver::stream *stream, driver::kernel *kernel, + driver::buffer *a, driver::buffer *b, driver::buffer *c, + size_t TM, size_t TN, size_t nthreads); + + // utils + size_t get_nflops(); + + // source + void src(std::ostream &os); + + // cpu_ref + template + void cpu_ref(OUT_DTYPE* O, + const IN_DTYPE* I, + const IN_DTYPE* F) + { + OUT_DTYPE acc; + for(int32_t p = 0; p < AH_; ++p) + for(int32_t q = 0; q < AW_; ++q) + for(int32_t bs = 0; bs < NB_; ++bs) + for(int32_t k = 0; k < NF_; ++k) + { + acc = 0; + for(int32_t c = 0; c < NC_; ++c){ + int32_t h = p + shift_h_[c]; + int32_t w = q + shift_w_[c]; + bool in_bounds = (h >= 0 && w >= 0 && h < AH_ && w < AW_); + IN_DTYPE a = in_bounds?I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]:0; + IN_DTYPE b = F[k + c*NF_]; + acc = std::fma(a, b, acc); + } + O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc; + } + } + +private: + // image size + int32_t NB_; + int32_t NC_; + int32_t AD_; + int32_t AH_; + int32_t AW_; + // filter size + int32_t BD_; + int32_t BH_; + int32_t BW_; + int32_t NF_; + // activation size + int32_t CD_; + int32_t CH_; + int32_t CW_; + // equivalent matmul + int32_t M_; + int32_t N_; + int32_t K_; + // shapes + std::vector shapes_a_; + std::vector shapes_b_; + std::vector shapes_c_; + // memory strides + std::vector ld_a_; + std::vector ld_b_; + std::vector ld_c_; + // shift values + std::vector shift_h_; + std::vector shift_w_; + // look-up tables + std::vector h_deltas_; + std::vector h_masks_; + driver::buffer* d_deltas_; + driver::buffer* d_masks_; + // data types + std::string a_ty_; + std::string b_ty_; + // convolution type + type ty_; + bool bias_; +}; + +} +} + +#endif diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp new file mode 100644 index 000000000..d07809f1f --- /dev/null +++ b/lib/dnn/shift.cpp @@ -0,0 +1,176 @@ +#include "triton/dnn/shift.h" + + +namespace triton{ +namespace dnn{ + +void shift::set_ld(const std::vector& shapes, + std::vector& ld) { + size_t size = shapes.size(); + ld.resize(size); + ld[4] = 1; + ld[3] = shapes[4]*ld[4]; + ld[2] = shapes[3]*ld[3]; + ld[1] = shapes[2]*ld[2]; + ld[0] = shapes[1]*ld[1]; +} + +shift::shift(int B, int NC, + int D, int H, int W, + int T, int R, int S, + int NF, + const std::vector& shift_h, const std::vector& shift_w, + std::string a_ty, std::string b_ty, + type ty, bool bias) + : NB_(B), NC_(NC), + AD_(D), AH_(H), AW_(W), + BD_(T), BH_(R), BW_(S), + NF_(NF), + shift_h_(shift_h), shift_w_(shift_w), + a_ty_(a_ty), b_ty_(b_ty), + ty_(ty), bias_(bias) { + // equivalent matmul + M_ = NB_*AH_*AW_; + N_ = NF_; + K_ = NC_; + // shapes + // input layout: C, H, W, BS + // filter layout: C, K + // output layout: K, H, W, BS + shapes_a_ = {NC, H, W, B}; + shapes_b_ = {NC, NF}; + shapes_c_ = {NF, H, W, B}; + // memory strides + set_ld(shapes_a_, ld_a_); +} + +void shift::build_deltas() { + h_deltas_.resize(NC_); + for(unsigned c = 0; c < NC_; c++){ + h_deltas_[c] = c*ld_a_[0]; + h_deltas_[c] += shift_h_[c]*ld_a_[1]; + h_deltas_[c] += shift_w_[c]*ld_a_[2]; + } +} + +void shift::build_masks() { + size_t S0 = NC_; + size_t S1 = BH_; + size_t S2 = BW_; + h_masks_.resize(S0*S1*S2); + for(size_t ph = 0; ph < S1; ++ph) + for(size_t pw = 0; pw < S2; ++pw){ + int32_t* ptr = &h_masks_[ph*S0 + pw*S0*S1]; + for(size_t i = 0; i < S0; ++i){ + bool in_bounds_h = shift_h_[i] + ph >= 0 && shift_h_[i] + ph < BH_; + bool in_bounds_w = shift_w_[i] + pw >= 0 && shift_w_[i] + pw < BW_; + ptr[i] = in_bounds_h && in_bounds_w; + } + } +} + +size_t shift::a_size(){ + return std::accumulate(shapes_a_.begin(), shapes_a_.end(), + 1, std::multiplies()); +} + +size_t shift::b_size(){ + return std::accumulate(shapes_b_.begin(), shapes_b_.end(), + 1, std::multiplies()); +} + +size_t shift::c_size(){ + return std::accumulate(shapes_c_.begin(), shapes_c_.end(), + 1, std::multiplies()); +} + +std::vector shift::c_shapes(){ + return shapes_c_; +} + +size_t shift::get_nflops() { + return 2 * M_ * N_ * K_; +} + + +void shift::init(driver::stream *stream, driver::cu_module *module) { + triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta"); + triton::driver::buffer* masks = ((triton::driver::cu_module*)module)->symbol("masks"); + stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data()); + stream->write(masks, false, 0, h_masks_.size()*4, h_masks_.data()); +} + +void shift::enqueue(driver::stream *stream, driver::kernel *kernel, + driver::buffer *a, driver::buffer *b, driver::buffer *c, + size_t TM, size_t TN, size_t nthreads) { + kernel->setArg(0, a); + kernel->setArg(1, b); + kernel->setArg(2, c); + kernel->setArg(3, M_); + kernel->setArg(4, N_); + kernel->setArg(5, K_); + kernel->setArg(6, NB_); + kernel->setArg(7, AH_); + kernel->setArg(8, AW_); + kernel->setArg(9, BH_); + kernel->setArg(10, BW_); + // dry run + std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; + stream->enqueue(kernel, grid, {nthreads, 1, 1}); +} + +void shift::src(std::ostream &os) { + os << +R"( +const tunable int32 TM = {16, 32, 64, 128}; +const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TK = {8}; + +__constant__ int32* delta = alloc_const int32[256]; +__constant__ int32* masks = alloc_const int32[8192]; + +void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, + int32 M, int32 N, int32 K, + int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){ + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + fp32* pxa[TM, TK] = a + rxa[:, newaxis]; + fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; + __constant__ int32* pd[TK] = delta + rka; + int32 pad_h = AR/2; + int32 pad_w = AS/2; + int32 rawhc[TM] = rxa / ABS; + int32 raw[TM] = rawhc % AW - pad_w; + int32 rahc[TM] = rawhc / AW; + int32 rah[TM] = rahc % AH - pad_h; + int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0); + int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0); + __constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1); + __constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :]; + for(int32 k = K; k > 0; k = k - TK){ + int32 delta[TK] = *pd; + fp32 *pa[TM, TK] = pxa + delta[newaxis, :]; + int1 m[TM, TK] = *pm > 0; + fp32 a[TM, TK] = m ? *pa : 0; + fp32 b[TN, TK] = *pb; + C = dot(a, trans(b), C); + pb = pb + TK*N; + pd = pd + TK; + pm = pm + TK; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis]; + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + @checkc *pc = C; +} +)"; +} + +} +}