diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
deleted file mode 100644
index b991c3726..000000000
--- a/include/triton/dnn/base.h
+++ /dev/null
@@ -1,116 +0,0 @@
-/* Copyright 2015-2017 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#ifndef TDL_INCLUDE_DNN_BASE_H
-#define TDL_INCLUDE_DNN_BASE_H
-
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-#include "triton/runtime/launch_info.h"
-
-namespace triton{
-
-namespace runtime{
-  class jit;
-}
-
-namespace dnn{
-
-
-enum autotuning_t{
-  FULL_TUNING,
-  PARTIAL_TUNING,
-  NO_TUNING
-};
-
-class base;
-struct launch_context_t{
-  base *op;
-  driver::kernel* kernel;
-  triton::runtime::launch_information info;
-};
-
-typedef std::vector<unsigned> params_t;
-
-class base {
-  friend class recompile_hash;
-  friend class recompile_equal;
-
-protected:
-  // leading dimensions
-  static void set_ld(const std::vector<int32_t>& shapes,
-                     std::vector<int32_t>& ld);
-  // list of retuning parameters
-  virtual std::vector<int64_t> retune_params() const = 0;
-
-private:
-  // initialize
-  virtual void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) = 0;
-  // deinitialize
-  virtual void deinit_impl() = 0;
-  // enqueue
-  virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                            std::vector<driver::buffer*> args,
-                            triton::runtime::launch_information info) = 0;
-  // number of flops
-  virtual size_t num_flops() const = 0;
-  // default parameters
-  virtual std::vector<params_t> search_space() const;
-  virtual params_t heuristics() const;
-  // obtain execution jit
-  std::pair<base*, triton::runtime::jit*> get_profile_impl(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune);
-
-public:
-  // constructor
-  base(const std::string& name);
-  // triton-c source
-  virtual void triton_c_src(std::ostream &os) const = 0;
-  // clone
-  virtual base* clone() const = 0;
-  // enqueue
-  base* enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
-  // get profile
-  launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
-
-private:
-  std::string name_;
-};
-
-
-struct recompile_equal{
-  bool operator()(base* x, base* y) const{
-    return typeid(*x) == typeid(*y) &&
-           x->retune_params() == y->retune_params();
-  }
-};
-
-struct recompile_hash{
-  unsigned operator()(base* x) const{
-    return x->retune_params()[0];
-  }
-};
-
-
-}
-}
-
-#endif
diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h
deleted file mode 100644
index 204ab631b..000000000
--- a/include/triton/dnn/batchnorm.h
+++ /dev/null
@@ -1,110 +0,0 @@
-/* Copyright 2015-2019 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#ifndef TDL_INCLUDE_DNN_BATCHNORM_H
-#define TDL_INCLUDE_DNN_BATCHNORM_H
-
-#include
-#include
-#include
-#include
-#include
-#include "triton/dnn/base.h"
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-
-namespace triton{
-namespace dnn{
-
-class batchnorm_forward: public base {
-private:
-  // init
-  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
-  void deinit_impl() { }
-
-  // enqueue
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    triton::runtime::launch_information info);
-  // number of flops
-  size_t num_flops() const;
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // clone
-  base* clone() const;
-
-public:
-  // constructor
-  batchnorm_forward(int C, int D, int H, int W, int B,
-                    std::string ty = "float", float eps = 1e-5);
-  // triton-c source
-  void triton_c_src(std::ostream &os) const;
-
-private:
-  int32_t C_;
-  int32_t D_;
-  int32_t H_;
-  int32_t W_;
-  int32_t B_;
-  std::string ty_;
-  float eps_;
-  int32_t DHWB_;
-  float rcpDHWB_;
-};
-
-class batchnorm_backward: public base{
-private:
-  // init
-  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
-  void deinit_impl() { }
-  // enqueue
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    runtime::launch_information info);
-  // number of flops
-  size_t num_flops() const;
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // clone
-  base* clone() const;
-
-public:
-  // constructor
-  batchnorm_backward(int C, int D, int H, int W, int B,
-                     std::string ty = "float", float eps = 1e-5);
-  // triton-c source
-  void triton_c_src(std::ostream &os) const;
-
-private:
-  int32_t C_;
-  int32_t D_;
-  int32_t H_;
-  int32_t W_;
-  int32_t B_;
-  std::string ty_;
-  float eps_;
-};
-
-}
-}
-
-#endif
diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h
deleted file mode 100644
index f42d5b9d8..000000000
--- a/include/triton/dnn/blocksparse/dot.h
+++ /dev/null
@@ -1,61 +0,0 @@
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-#include "triton/dnn/base.h"
-#include
-
-namespace triton{
-namespace dnn{
-namespace blocksparse{
-
-enum op_t{
-  FPROP,
-  BPROP,
-  WGRAD
-};
-
-class dot: public base {
-private:
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    triton::runtime::launch_information info);
-  // number of flops
-  size_t num_flops() const;
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // default parameters
-  std::vector<params_t> search_space() const;
-  params_t heuristics() const;
-  // init
-  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
-  // deinit
-  void deinit_impl();
-  // source
-  std::string triton_c_src_ydx() const;
-  std::string triton_c_src_dw() const;
-public:
-  // constructor
-  dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
-  // triton-c source
-  void triton_c_src(std::ostream &os) const;
-  // locks
-  driver::buffer* get_locks() const;
-  // clone
-  base* clone() const;
-
-private:
-  std::string ab_ty_;
-  std::string c_ty_;
-  int32_t N_;
-  int32_t S_;
-  int32_t C_;
-  int32_t K_;
-  int32_t BS_;
-  int32_t nlocks_;
-  int32_t nblocks_;
-  std::shared_ptr<driver::buffer> locks_;
-  op_t op_;
-};
-
-}
-}
-}
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
deleted file mode 100644
index 5a167531d..000000000
--- a/include/triton/dnn/conv.h
+++ /dev/null
@@ -1,155 +0,0 @@
-#include
-#include
-#include
-#include
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-#include "triton/dnn/base.h"
-
-namespace triton{
-namespace dnn{
-
-class conv: public base{
-public:
-  enum type {
-    FPROP,
-    BPROP,
-    WGRAD
-  };
-
-private:
-  // initialize
-  std::tuple<int32_t, int32_t, int32_t, int32_t>
-  unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW);
-  void build_b_deltas();
-  void build_a_deltas();
-  void build_masks();
-  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
-  void deinit_impl() { }
-
-  // enqueue
-  std::array<size_t, 3> get_grid(size_t TM, size_t TN);
-  void set_arg(driver::kernel *kernel,
-               driver::buffer *a, driver::buffer *b, driver::buffer *c,
-               driver::buffer *bias);
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    triton::runtime::launch_information info);
-  // number of flops
-  size_t num_flops() const;
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // clone
-  base* clone() const;
-
-public:
-
-  conv(int B, int NC,
-       int D, int H, int W,
-       int T, int R, int S, int NF,
-       int stride_d, int stride_h, int stride_w,
-       int pad_d, int pad_h, int pad_w,
-       int upsample_d, int upsample_h, int upsample_w,
-       std::string a_ty = "float", std::string b_ty = "float",
-       type ty = FPROP, bool bias = false);
-
-  // accessors
-  size_t a_size();
-  size_t b_size();
-  size_t c_size();
-  std::vector<int32_t> c_shapes();
-  // default params
-  std::vector<unsigned> default_params();
-
-  // triton-c source code
-  void triton_c_src(std::ostream &os) const;
-
-  // cpu reference implementations
-  template<class IN_DTYPE, class OUT_DTYPE>
-  void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
-  template<class IN_DTYPE, class OUT_DTYPE>
-  void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
-  template<class IN_DTYPE, class OUT_DTYPE>
-  void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
-
-private:
-  // image size
-  int32_t NB_;
-  int32_t NC_;
-  int32_t AD_;
-  int32_t AH_;
-  int32_t AW_;
-  // filter size
-  int32_t BD_;
-  int32_t BH_;
-  int32_t BW_;
-  int32_t NF_;
-  // activation size
-  int32_t CD_;
-  int32_t CH_;
-  int32_t CW_;
-  // striding
-  int32_t stride_d_;
-  int32_t stride_h_;
-  int32_t stride_w_;
-  // padding
-  int32_t pad_d_;
-  int32_t pad_h_;
-  int32_t pad_w_;
-  // upsampling
-  int32_t upsample_d_;
-  int32_t upsample_h_;
-  int32_t upsample_w_;
-  // equivalent matmul
-  int32_t M_;
-  int32_t N_;
-  int32_t K_;
-  // helpers
-  int32_t Fs_;
-  int32_t TK_;
-  int32_t Luts_;
-  // memory strides for A
-  std::vector<int32_t> shapes_a_;
-  std::vector<int32_t> ld_a_;
-  // memory strides for B
-  std::vector<int32_t> shapes_b_;
-  std::vector<int32_t> ld_b_;
-  // memory stride for C
-  std::vector<int32_t> shapes_c_;
-  std::vector<int32_t> ld_c_;
-  // constant memory
-  std::vector<int32_t> h_a_deltas_;
-  std::vector<int32_t> h_b_deltas_;
-  std::vector<int32_t> h_masks_;
-  driver::buffer* d_a_deltas_;
-  driver::buffer* d_b_deltas_;
-  driver::buffer* d_masks_;
-  driver::buffer* d_locks_;
-  bool is_a_deltas_cst;
-  bool is_b_deltas_cst_;
-  bool is_mask_cst_;
-  // data type
-  std::string a_ty_;
-  std::string b_ty_;
-  // conv type
-  type ty_;
-  bool bias_;
-  bool b_trans_;
-  bool b_lut_;
-  // axis index
-  int32_t a_inner_idx_;
-  int32_t a_outer_idx_;
-  int32_t a_pix_idx_;
-  int32_t b_inner_idx_;
-  int32_t b_outer_idx_;
-  int32_t b_pix_idx_;
-  int32_t c_outer_0_idx_;
-  int32_t c_outer_1_idx_;
-  int32_t c_pix_idx;
-  // maximum grid size for loc
-  int32_t max_grid_0_;
-  int32_t max_grid_1_;
-};
-
-}
-}
diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h
deleted file mode 100644
index f36d05db5..000000000
--- a/include/triton/dnn/dot.h
+++ /dev/null
@@ -1,79 +0,0 @@
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-#include "triton/dnn/base.h"
-#include
-
-namespace triton{
-namespace dnn{
-
-class dot: public base {
-private:
-  // initialize
-  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information);
-  void deinit_impl() { }
-
-  // enqueue
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    triton::runtime::launch_information info);
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // default parameters
-  virtual std::vector<params_t> search_space() const;
-  virtual params_t heuristics() const;
-
-public:
-  dot(int M, int N, int K, bool AT, bool BT,
-      std::string a_ty, std::string b_ty, std::string c_ty,
-      unsigned align_lda, unsigned align_ldb, unsigned align_ldc);
-
-  // number of flops
-  size_t num_flops() const;
-
-  // triton-c source
-  void triton_c_src(std::ostream &os) const;
-
-  // clone
-  base* clone() const;
-
-  // CPU reference implementation
-  template<class T, bool AT, bool BT>
-  static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
-                      size_t M, size_t N, size_t K){
-    for(size_t m = 0; m < M; m++)
-    for(size_t n = 0; n < N; n++){
-      float acc = 0;
-      for(size_t k = 0; k < K; k++)
-        acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
-      c[m + n*M] = static_cast<T>(acc);
-    }
-  }
-  template<class T>
-  void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
-    if(AT_ && BT_)
-      dot::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
-    else if(AT_ && !BT_)
-      dot::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
-    else if(!AT_ && BT_)
-      dot::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
-    else
-      dot::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
-  }
-
-private:
-  int32_t M_;
-  int32_t N_;
-  int32_t K_;
-  bool AT_;
-  bool BT_;
-  std::string a_ty_;
-  std::string b_ty_;
-  std::string c_ty_;
-  unsigned align_lda_;
-  unsigned align_ldb_;
-  unsigned align_ldc_;
-  driver::buffer *locks_;
-};
-
-}
-}
diff --git a/include/triton/dnn/heuristics.h b/include/triton/dnn/heuristics.h
deleted file mode 100644
index 56c23642b..000000000
--- a/include/triton/dnn/heuristics.h
+++ /dev/null
@@ -1,186 +0,0 @@
-#ifndef TRITON_DNN_HEURISTICS_H
-#define TRITON_DNN_HEURISTICS_H
-
-#include
-#include "triton/dnn/base.h"
-
-namespace triton{
-namespace dnn{
-
-/* Dense matrix multiplication */
-
-typedef std::vector<unsigned> params_t;
-typedef std::tuple<bool, bool> trans_key_t;
-typedef std::tuple<size_t, size_t> size_key_t;
-static const std::map<trans_key_t, std::map<size_key_t, params_t>> dot_params = {
-  /* NN */
-  {trans_key_t(false, false), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
-    {size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
-    {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
-    {size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
-    {size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-    {size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
-    {size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
-    {size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
-    {size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
-    {size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-    {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
-    {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
-    {size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
-    {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
-    {size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
-    {size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
-    {size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
-  }},
-  /* NT */
-  {trans_key_t(false, true), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
-    {size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
-    {size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
-    {size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
-    {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-    {size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
-    {size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
-    {size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
-    {size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
-    {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
-    {size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
-    {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
-    {size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
-    {size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
-    {size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
-    {size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
-    {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
-    {size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
-    {size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
-    {size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
-    {size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
-  }},
-  /* TN */
-  {trans_key_t(true, false), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
-    {size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
-    {size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
-    {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
-    {size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
-  }},
-  /* TT */
-  {trans_key_t(true, true), std::map<size_key_t, params_t>{
-    {size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
-    {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
-    {size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
-    {size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
-    {size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
-    {size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
-    {size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
-    {size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
-    {size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
-    {size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
-    {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
-    {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
-    {size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
-    {size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
-    {size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
-    {size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
-    {size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
-    {size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
-    {size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
-  }}
-};
-
-// small search space for partial auto-tuning
-inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
-  std::vector<params_t> result;
-  for(auto x: dot_params.at(trans_key_t{AT, BT}))
-    result.push_back(x.second);
-  return result;
-}
-
-// simple parameter heuristics
-inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
-  size_t TM = 128;
-  size_t TN = 128;
-//  return {4, 4, 128, 8, 4, 128, 2, 2, 2, 2, 32, 32, 16, 1};
-  return dot_params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
-}
-
-
-/* Block-sparse matrix multiplication */
-
-static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
-  /* FPROP */
-  {{true, 32}, std::map<size_t, params_t>{
-    {32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
-    {64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
-    {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
-  }},
-
-  {{true, 16}, std::map<size_t, params_t>{
-    {32, {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}},
-    {64, {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}},
-    {128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}}
-  }},
-
-  {{true, 8}, std::map<size_t, params_t>{
-    {32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}},
-    {64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}},
-    {128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}
-  }},
-
-  /* BPROP */
-  {{false, 32}, std::map<size_t, params_t>{
-    {32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
-    {64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
-    {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
-  }},
-
-  {{false, 16}, std::map<size_t, params_t>{
-    {32, {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}},
-    {64, {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}},
-    {128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}}
-  }},
-
-  {{false, 8}, std::map<size_t, params_t>{
-    {32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}},
-    {64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}},
-    {128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}
-  }}
-};
-
-// small search space for partial auto-tuning
-inline std::vector<params_t> bsdot_search_space(bool is_fprop, size_t block_size) {
-  std::vector<params_t> result;
-  for(auto x: bsdot_params.at({is_fprop, block_size}))
-    result.push_back(x.second);
-  return result;
-}
-
-// simple parameter heuristics
-inline params_t bsdot_heuristics(bool is_fprop, size_t block_size, size_t N, size_t S) {
-  return bsdot_params.at({is_fprop,block_size}).at(128);
-}
-
-
-}
-}
-
-#endif
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
deleted file mode 100644
index 4590c476e..000000000
--- a/include/triton/dnn/shift.h
+++ /dev/null
@@ -1,192 +0,0 @@
-/* Copyright 2015-2017 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#ifndef TDL_INCLUDE_DNN_SHIFT_H
-#define TDL_INCLUDE_DNN_SHIFT_H
-
-#include
-#include
-#include
-#include
-#include
-#include "triton/dnn/base.h"
-#include "triton/driver/stream.h"
-#include "triton/driver/kernel.h"
-
-namespace triton{
-namespace dnn{
-
-enum op_t {
-  FPROP,
-  BPROP,
-  WGRAD
-};
-
-enum layout_t {
-  NCHW,
-  CHWN
-};
-
-class shift: public base {
-private:
-  // initialize and enqueue
-  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
-  void deinit_impl();
-  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                    std::vector<driver::buffer*> args,
-                    triton::runtime::launch_information info);
-  std::vector<params_t> search_space() const;
-  params_t heuristics() const;
-
-public:
-
-  shift(int B, int NC,
-        int D, int H, int W,
-        int T, int R, int S, int NF,
-        int stride_h, int stride_w,
-        const int32_t* shift_h, const int32_t* shift_w,
-        std::string a_ty = "float", std::string b_ty = "float",
-        op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
-
-  // look-up table
-  void build_delta_a();
-  void build_masks();
-  // accessors
-  size_t c_size();
-  std::vector<int32_t> c_shapes();
-  // equivalent GEMM
-  bool AT() const;
-  bool BT() const;
-  size_t M() const;
-  size_t N() const;
-  size_t K() const;
-  size_t lda() const;
-  size_t ldb() const;
-  size_t ldc() const;
-  // number of flops
-  size_t num_flops() const;
-  // source
-  void triton_c_src(std::ostream &os) const;
-  // retuning parameters
-  std::vector<int64_t> retune_params() const;
-  // clone
-  base* clone() const;
-  // cpu reference
-  template<class IN_DTYPE, class OUT_DTYPE>
-  void cpu_ref(OUT_DTYPE* O,
-               const IN_DTYPE* I,
-               const IN_DTYPE* F)
-  {
-    OUT_DTYPE acc;
-    for(int32_t p = 0; p < AH_; ++p)
-    for(int32_t q = 0; q < AW_; ++q)
-    for(int32_t bs = 0; bs < B_; ++bs)
-    for(int32_t k = 0; k < F_; ++k)
-    {
-      acc = 0;
-      for(int32_t c = 0; c < C_; ++c){
-        int32_t h = p;
-        int32_t w = q;
-        if(h >= BH_/2 && h < AH_ - BH_/2
-           && w >= BW_/2 && w < AW_ - BW_/2){
-          h += shift_h_[c];
-          w += shift_w_[c];
-        }
-        IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_];
-        IN_DTYPE b = F[k + c*F_];
-        acc = std::fma(a, b, acc);
-      }
-      O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc;
-    }
-  }
-
-private:
-  int32_t MAX_C_;
-  int32_t TK_;
-  // image size
-  int32_t B_;
-  int32_t C_;
-  int32_t AD_;
-  int32_t AH_;
-  int32_t AW_;
-  // filter size
-  int32_t BD_;
-  int32_t BH_;
-  int32_t BW_;
-  int32_t F_;
-  // activation size
-  int32_t CD_;
-  int32_t CH_;
-  int32_t CW_;
-  // interior image size
-  int32_t IAD_;
-  int32_t IAH_;
-  int32_t IAW_;
-  // interior activation size
-  int32_t ICD_;
-  int32_t ICH_;
-  int32_t ICW_;
-  // equivalent matmul
-  int32_t M_;
-  int32_t N_;
-  int32_t K_;
-  // shapes
-  std::vector<int32_t> shapes_c_;
-  // strides
-  int32_t stride_d_;
-  int32_t stride_h_;
-  int32_t stride_w_;
-  // memory strides
-  int32_t lda_n_, lda_c_, lda_h_, lda_w_;
-  int32_t ldb_n_, ldb_c_, ldb_h_, ldb_w_;
-  int32_t ldc_n_, ldc_f_, ldc_h_, ldc_w_;
-  // shift values
-  const int32_t* shift_h_;
-  const int32_t* shift_w_;
-  bool shift_edge_h_;
-  bool shift_edge_w_;
-  // look-up tables
-  std::vector<int32_t> h_delta_a;
-  std::vector<int32_t> h_delta_b;
-  driver::buffer* d_delta_a;
-  driver::buffer* d_delta_b;
-  // data types
-  std::string a_ty_;
-  std::string b_ty_;
-  std::string c_ty_;
-  // convolution type
-  op_t op_;
-  bool bias_;
-  // transpose
-  bool AT_;
-  bool BT_;
-  // layout
-  layout_t layout_;
-  // locks
-  size_t max_locks_;
-  driver::buffer *locks_;
-};
-
-}
-}
-
-#endif
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
deleted file mode 100644
index a7fb5deeb..000000000
--- a/include/triton/runtime/jit.h
+++ /dev/null
@@ -1,136 +0,0 @@
-#ifndef TDL_INCLUDE_JIT_H
-#define TDL_INCLUDE_JIT_H
-
-#include
-#include
-#include "llvm/IR/LLVMContext.h"
-#include "triton/ir/context.h"
-#include "triton/ir/print.h"
-#include "triton/driver/module.h"
-#include "triton/driver/kernel.h"
-#include "triton/codegen/selection/selection.h"
-#include "triton/codegen/selection/target.h"
-#include "triton/codegen/analysis/tune.h"
-#include "triton/codegen/analysis/shmem/allocation.h"
-#include "triton/codegen/analysis/shmem/liveness.h"
-#include "triton/codegen/analysis/shmem/info.h"
-#include "triton/codegen/analysis/alignment.h"
-#include "triton/codegen/transform/dce.h"
-#include "triton/codegen/transform/peephole.h"
-#include "triton/codegen/transform/shmem/barriers.h"
-#include "triton/codegen/transform/reassociate.h"
-#include "triton/codegen/transform/vectorize.h"
-#include "triton/runtime/launch_info.h"
-#include
-
-namespace llvm {
-  class Module;
-
-}
-
-namespace triton {
-
-namespace lang{
-class translation_unit;
-}
-
-namespace codegen{
-namespace analysis{
-class tune;
-}
-}
-
-namespace ir {
-class module;
-class context;
-class metaparameter;
-}
-
-namespace runtime{
-
-class jit {
-public:
-  typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
-
-  struct tune_res_t{
-    double perf;
-    std::vector<unsigned> params;
-  };
-
-  struct passes_wrapper {
-    passes_wrapper(codegen::target* target)
-      : tune(0),
-        shmem_liveness(&shmem_info),
-        shmem_allocation(&shmem_liveness, &shmem_info, &tune),
-        shmem_barriers(&shmem_allocation, &shmem_info),
-        vectorize(&tune),
-        selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
-        dce(),
-        peephole(),
-        alignment_info(),
-        reassociate(&tune),
-        target_(target) { }
-
-    void target_independent(ir::module &module) {
-      peephole.run(module);
-      dce.run(module);
-    }
-
-    void target_dependent(ir::module &module) {
-      reassociate.run(module);
-      peephole.run(module);
-      if(target_->is_gpu()){
-        shmem_info.run(module);
-        shmem_liveness.run(module);
-        shmem_allocation.run();
-        shmem_barriers.run(module);
-      }
-      alignment_info.run(module);
-      vectorize.run(module);
-      dce.run(module);
-    }
-
-    codegen::selection selection;
-    codegen::analysis::tune tune;
-    codegen::analysis::shmem::info shmem_info;
-    codegen::analysis::shmem::liveness shmem_liveness;
-    codegen::analysis::shmem::allocation shmem_allocation;
-    codegen::analysis::alignment_info alignment_info;
-    codegen::transform::shmem_barriers shmem_barriers;
-    codegen::transform::vectorize vectorize;
-    codegen::transform::dce dce;
-    codegen::transform::peephole peephole;
-    codegen::transform::reassociate reassociate;
-    codegen::target* target_;
-  };
-
-private:
-  std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
-  std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
-  std::unique_ptr<triton::ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
-  triton::lang::translation_unit *parse_program(const char *name, const char *src);
-
-public:
-  jit(driver::context* context, unsigned nthreads = 4);
-  ~jit();
-  std::vector<unsigned> get_valid(const char *name, const char *src);
-  tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> &targets = {});
-  void add_module(ir::module &module, const std::vector<unsigned>& params = {});
-  void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
-  driver::kernel* get_function(const char* name);
-  launch_information get_launch_info(const char* name);
-
-private:
-  std::map modules_;
-  driver::context* driver_context_;
-  llvm::LLVMContext llvm_context_;
-  ir::context triton_context_;
-  std::map<std::string, launch_information> launch_info_map_;
-  std::shared_ptr<codegen::target> target_;
-  unsigned nthreads_;
-};
-
-}
-}
-
-#endif
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
deleted file mode 100644
index 86d031564..000000000
--- a/lib/dnn/base.cpp
+++ /dev/null
@@ -1,94 +0,0 @@
-#include
-#include
-#include "triton/dnn/base.h"
-#include "triton/runtime/jit.h"
-#include "triton/tools/bench.hpp"
-
-namespace triton{
-namespace dnn{
-
-namespace rt = triton::runtime;
-
-
-void base::set_ld(const std::vector<int32_t>& shapes,
-                  std::vector<int32_t>& ld) {
-  size_t size = shapes.size();
-  ld.resize(size);
-  ld[size - 1] = 1;
-  for(int i = size - 1; i >= 1; i--)
-    ld[i - 1] = shapes[i] * ld[i];
-}
-
-
-base::base(const std::string& name)
-  : name_(name) { }
-
-std::vector<params_t> base::search_space() const {
-  return {};
-}
-
-params_t base::heuristics() const {
-  return *search_space().begin();
-}
-
-std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
-  static std::unordered_map<base*, std::unique_ptr<rt::jit>, recompile_hash, recompile_equal> m_jit;
-  driver::context* ctx = stream->context();
-  rt::jit* jit;
-  /* the current template has not already been compiled */
-  if(m_jit.find(this) == m_jit.end()) {
-    base* clone = this->clone();
-    jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
-    std::ostringstream oss;
-    clone->triton_c_src(oss);
-    std::string src = oss.str();
-    auto benchmark = [&](triton::driver::kernel* kernel,
-                         rt::launch_information info) {
-      // launch info
-      clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
-      clone->enqueue_impl(stream, kernel, args, info);
-      stream->synchronize();
-      double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
-      clone->deinit_impl();
-//      std::cout << ts * 1e-6 << std::endl;
-      return num_flops() / ts * 1e-3;
-    };
-    // auto-tune and save result
-    if(autotune == FULL_TUNING || autotune == PARTIAL_TUNING) {
-      std::vector<params_t> space = {};
-      if(autotune == PARTIAL_TUNING)
-        space = search_space();
-      rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
-      jit->add_module(name_.c_str(), src.c_str(), best.params);
-    }
-    else{
-      params_t params = heuristics();
-      jit->add_module(name_.c_str(), src.c_str(), params);
-    }
-    triton::driver::kernel* kernel = jit->get_function(name_.c_str());
-    rt::launch_information info = jit->get_launch_info(name_.c_str());
-    clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
-  }
-  /* retrieved compiled template */
-  else {
-    jit = m_jit.at(this).get();
-  }
-  auto it = m_jit.find(this);
-  return {it->first, jit};
-}
-
-base* base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
-  launch_context_t info = get_launch_context(stream, args, autotune);
-  info.op->enqueue_impl(stream, info.kernel, args, info.info);
-  return info.op;
-}
-
-launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
-  std::pair<base*, rt::jit*> profile = get_profile_impl(stream, args, autotune);
-  driver::kernel* kernel = profile.second->get_function(name_.c_str());
-  rt::launch_information info = profile.second->get_launch_info(name_.c_str());
-  return {profile.first, kernel, info};
-}
-
-}
-}
diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp
deleted file mode 100644
index fe785afdd..000000000
--- a/lib/dnn/batchnorm.cpp
+++ /dev/null
@@ -1,227 +0,0 @@
-/* Copyright 2015-2019 Philippe Tillet
-*
-* Permission is hereby granted, free of charge, to any person obtaining
-* a copy of this software and associated documentation files
-* (the "Software"), to deal in the Software without restriction,
-* including without limitation the rights to use, copy, modify, merge,
-* publish, distribute, sublicense, and/or sell copies of the Software,
-* and to permit persons to whom the Software is furnished to do so,
-* subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be
-* included in all copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
-* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
-* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-*/
-
-#include "triton/dnn/batchnorm.h"
-
-namespace triton{
-namespace dnn{
-
-/* ---------------
- * Forward
- * --------------- */
-
-batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm_forward"),
-    C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
-  DHWB_ = D_*H_*W_*B_;
-  rcpDHWB_ = (float)1 / DHWB_;
-}
-
-size_t batchnorm_forward::num_flops() const {
-  return C_*DHWB_;
-}
-
-
-std::vector<int64_t> batchnorm_forward::retune_params() const {
-  return {C_, D_, H_, W_, B_};
-}
-
-base* batchnorm_forward::clone() const {
-  return new batchnorm_forward(*this);
-}
-
-void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                                     std::vector<driver::buffer*> args,
-                                     runtime::launch_information info)
-{
-  driver::buffer *y = args[0], *m = args[1], *v = args[2];
-  driver::buffer *x = args[3], *g = args[4], *b = args[5];
-  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
-  kernel->setArg(0, y);
-  kernel->setArg(1, m);
-  kernel->setArg(2, v);
-  kernel->setArg(3, x);
-  kernel->setArg(4, g);
-  kernel->setArg(5, b);
-  kernel->setArg(6, DHWB_);
-  kernel->setArg(7, rcpDHWB_);
-  kernel->setArg(8, eps_);
-  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
-}
-
-void batchnorm_forward::triton_c_src(std::ostream &os) const {
-  os <<
-R"(
-const tunable int TM = {128};
-
-void batchnorm_forward(float *Y, float *M, float *V,
-                       restrict read_only float *X,
-                       restrict read_only float *G,
-                       restrict read_only float *B,
-                       int DHWN,
-                       float rcpDHWN, float eps) {
-  int rx[TM] = 0 ... TM;
-  float *px[TM];
-  float x[TM] = 0;
-  int c = get_program_id(1);
-  float g = *(G + c);
-  float b = *(B + c);
-
-  float mean[TM] = 0;
-  px = X + rx + c*DHWN;
-  for(int i = 0; i < DHWN; i = i + TM){
-    x = *px;
-    mean = mean + x;
-    px = px + TM;
-  }
-  float *pm = M + c;
-  float m = __sum(mean, 0) * rcpDHWN;
-  *pm = m;
-
-  float var[TM] = 0;
-  px = X + rx + c*DHWN;
-  for(int i = 0; i < DHWN; i = i + TM){
-    x = *px;
-    x = x - m;
-    var = var + x*x;
-    px = px + TM;
-  }
-  float v = __sum(var, 0) * rcpDHWN;
-  float *pv = V + c;
-  *pv = v;
-  float rstdg = 1 / sqrt(v + eps) * g;
-
-  px = X + rx + c*DHWN;
-  float* py[TM] = Y + rx + c*DHWN;
-  for(int i = 0; i < DHWN; i = i + TM){
-    x = *px;
-    float y[TM] = (x - m)*rstdg + b;
-    *py = y;
-    px = px + TM;
-    py = py + TM;
-  }
-})";
-}
-
-/* ---------------
- * Backward
- * --------------- */
-
-batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm_backward"),
-    C_(C), D_(D), H_(H), W_(W), B_(B),
-    ty_(ty), eps_(eps)
-{ }
-
-size_t batchnorm_backward::num_flops() const {
-  return C_*D_*H_*W_*B_;
-}
-
-std::vector<int64_t> batchnorm_backward::retune_params() const {
-  return {C_, D_, H_, W_, B_};
-}
-
-base* batchnorm_backward::clone() const {
-  return new batchnorm_backward(*this);
-}
-
-void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                                      std::vector<driver::buffer*> args,
-                                      runtime::launch_information info) {
-  driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
-  driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
-  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
-  kernel->setArg(0, dx);
-  kernel->setArg(1, dg);
-  kernel->setArg(2, db);
-  kernel->setArg(3, dy);
-  kernel->setArg(4, x);
-  kernel->setArg(5, g);
-  kernel->setArg(6, m);
-  kernel->setArg(7, v);
-  kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
-  kernel->setArg(9, (float)1/(D_*H_*W_*B_));
-  kernel->setArg(10, eps_);
-  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
-}
-
-void batchnorm_backward::triton_c_src(std::ostream &os) const {
-  os <<
-R"(
-const tunable int TM = {128};
-
-void batchnorm_backward(float *DX, float *DG, float *DB,
-                        restrict read_only float *DY,
-                        restrict read_only float *X,
-                        restrict read_only float *G,
-                        restrict read_only float *M,
-                        restrict read_only float *V,
-                        int DHWN, float rcpDHWN, float epsilon) {
-  int rx[TM] = 0 ... TM;
-  int c = get_program_id(1);
-  int offset = c*DHWN;
-  float g = *(G + c);
-  float mean = *(M + c);
-  float var = *(V + c);
-  float rstd = 1 / sqrt(var + epsilon);
-  float* px[TM];
-  float* pdx[TM];
-  float* pdy[TM];
-
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  float dg[TM] = 0;
-  float db[TM] = 0;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
-    dg = dg + dy*(x - mean)*rstd;
-    db = db + dy;
-    px = px + TM;
-    pdy = pdy + TM;
-  }
-  float sdg = __sum(dg, 0);
-  float sdb = __sum(db, 0);
-  float *pdg = DG + c;
-  float *pdb = DB + c;
-  *pdg = sdg;
-  *pdb = sdb;
-
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  pdx = DX + rx + offset;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
-    float xhat[TM] = (x - mean) * rstd;
-    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    float dx[TM] = (dy - xtmp) * rstd * g;
-    *pdx = dx;
-    px = px + TM;
-    pdy = pdy + TM;
-    pdx = pdx + TM;
-  }
-})";
-}
-
-}
-}
diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp
deleted file mode 100644
index b155f9c89..000000000
--- a/lib/dnn/blocksparse/dot.cpp
+++ /dev/null
@@ -1,238 +0,0 @@
-#include "triton/dnn/heuristics.h"
-#include "triton/dnn/blocksparse/dot.h"
-
-namespace triton{
-namespace dnn{
-namespace blocksparse{
-
-
-size_t dot::num_flops() const {
-  return 2.*nblocks_*BS_*BS_*N_;
-}
-
-std::vector<int64_t> dot::retune_params() const{
-  return {N_, S_, C_, BS_, nlocks_, op_};
-}
-
-std::vector<params_t> dot::search_space() const {
-  return bsdot_search_space(op_ == FPROP, BS_);
-}
-
-params_t dot::heuristics() const {
-  return bsdot_heuristics(op_ == FPROP, BS_, N_, S_);
-}
-
-base * dot::clone() const {
-  return new dot(*this);
-}
-
-dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
-         const std::string& ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op):
-  base("bsdot"),
-  N_(N), K_(K), S_(S), C_(C),
-  ab_ty_(ty), c_ty_(ty),
-  BS_(BS), nlocks_(nlocks), nblocks_(nblocks), op_(op){
-}
-
-void dot::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
-  int32_t TM = info.globals["TM"];
-  size_t grid_0 = (N_ + TM - 1) / TM;
-  if(nlocks_ && !locks_){
-    locks_.reset(triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4));
-    ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-  }
-}
-
-void dot::deinit_impl() {
-}
-
-void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
-                       std::vector<driver::buffer*> args, runtime::launch_information info) {
-  driver::buffer *a = args[0];
-  driver::buffer *b = args[1];
-  driver::buffer *c = args[2];
-  driver::buffer *lut = args[3];
-  kernel->setArg(0, a);
-  kernel->setArg(1, b);
-  kernel->setArg(2, c);
-  if(op_ == FPROP || op_ == BPROP){
-    kernel->setArg(3, N_);
-    kernel->setArg(4, BS_);
-    kernel->setArg(5, N_);
-  }
-  else{
-    kernel->setArg(3, N_);
-    kernel->setArg(4, N_);
-    kernel->setArg(5, BS_);
-  }
-  kernel->setArg(6, N_);
-  kernel->setArg(7, lut);
-  kernel->setArg(8, locks_.get());
-  kernel->setArg(9, nlocks_);
-  if(op_ == FPROP || op_ == BPROP){
-    int32_t TM = info.globals["TM"];
-    size_t grid_0 = (N_ + TM - 1) / TM;
-    size_t grid_1 = S_;
-    if(nlocks_)
-      ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-    stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
-  }
-  else{
-    size_t grid_0 = nblocks_;
-    stream->enqueue(kernel, {grid_0, 1, 1}, {info.num_threads, 1, 1});
-  }
-}
-
-driver::buffer* dot::get_locks() const {
-  return locks_.get();
-}
-
-std::string dot::triton_c_src_ydx() const {
-  bool AT = (op_ == WGRAD);
-  bool BT = (op_ == FPROP);
-  std::string usea = AT ? "trans(a)" : "a";
-  std::string useb = BT ? "trans(b)" : "b";
-  std::string sizea = "TM, TK";
-  std::string sizeb = BT ? "TN, TK" : "TK, TN";
-  std::string bca0 = ":, newaxis";
-  std::string bca1 = "newaxis, :";
-  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
-  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
-  std::string lda0 = AT ? "*lda" : "";
-  std::string lda1 = AT ? "" : "*lda";
-  std::string ldb0 = BT ? "" : "*ldb";
-  std::string ldb1 = BT ? "*ldb" : "" ;
-  std::string result =
-  R"(
-  const tunable int TM = {16, 32, 64, 128};
-  const tunable int TN = {)" + std::to_string(BS_) + R"(};
-  const tunable int TK = {)" + std::to_string(BS_) + R"(};
-
-  void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
-             restrict read_only align(16) )" + ab_ty_ + R"( *B,
-             )" + c_ty_ + R"(* C,
-             int lda, int ldb, int ldc,
-             int N, int* lut,
-             int* locks, int nlocks) {
-    int ridx = get_program_id(0);
-    float acc[TM, TN] = 0;
-    int rka[TK] = 0 ... TK;
-    int rkb[TK] = 0 ... TK;
-    int *header = lut + get_program_id(1) * 4;
-    int offset = *(header + 0);
-    int K = *(header + 1);
-    int column = *(header + 2);
-    int lockid = *(header + 3);
-    int rxa[TM] = ridx * TM + (0 ... TM);
-    int ryb[TN] = 0 ... TN;
-    int *plut = lut + offset * 2;
-    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
-    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
-    bool checka[TM, TK] = (rxa < N)[:, newaxis];
-    for(int k = K; k > 0; k = k - 1) {
-      int ak = *(plut + 0);
-      int bk = *(plut + 1);
-      )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
-      )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
-      )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
-      )" + ab_ty_ + " b[" + sizeb + R"(] = *pb;
-      acc = dot()" + usea + ", " + useb + R"(, acc);
-      plut = plut + 2;
-    }
-    int rxc[TM] = ridx * TM + (0 ... TM);
-    int ryc[TN] = column * TN + (0 ... TN);
-    )" + c_ty_ + R"( c[TM, TN] = acc;
-    )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
-    bool checkc[TM, TN] = (rxc < N)[:, newaxis];
-    if(lockid == 0) {
-      @checkc *pc = c;
-    }
-    else {
-      int *plock = locks + ridx*nlocks + lockid - 1;
-      int *pcount = plock + get_num_program(0)*nlocks;
-      while(__atomic_cas(plock, 0, 1));
-      int count = *pcount;
-      if(count == 0)
-        @checkc *pc = c;
-      else
-        @checkc *pc = c + *pc;
-      __atomic_exch(pcount, 1);
-      __atomic_exch(plock, 0);
-    }
-  })";
-
-  return result;
-}
-
-std::string dot::triton_c_src_dw() const {
-  bool AT = (op_ == WGRAD);
-  bool BT = (op_ == FPROP);
-  std::string usea = AT ? "trans(a)" : "a";
-  std::string useb = BT ? "trans(b)" : "b";
-  std::string sizea = AT ? "TK, TM" : "TM, TK";
-  std::string sizeb = BT ? "TN, TK" : "TK, TN";
-  std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
-  std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
-  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
-  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
-  std::string lda0 = AT ? "*lda" : "";
-  std::string lda1 = AT ? "" : "*lda";
-  std::string ldb0 = BT ? "" : "*ldb";
-  std::string ldb1 = BT ? "*ldb" : "" ;
"*ldb" : "" ; - std::string result = - R"( - const tunable int TM = {)" + std::to_string(BS_) + R"(}; - const tunable int TN = {)" + std::to_string(BS_) + R"(}; - const tunable int TK = {32}; - - void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A, - restrict read_only align(16) )" + ab_ty_ + R"( *B, - )" + c_ty_ + R"(* C, - int lda, int ldb, int ldc, - int N, int* lut, - int* locks, int nlocks) { - int ridx = get_program_id(0); - float acc[TM, TN] = 0; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - int *header = lut + ridx * 2; - int offx = *(header + 0); - int offy = *(header + 1); - int rxa[TM] = offx*TM + (0 ... TM); - int ryb[TN] = offy*TN + (0 ... TN); - bool checka[TK, TM] = (rka < N)[:, newaxis]; - bool checkb[TK, TN] = (rkb < N)[:, newaxis]; - int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(; - int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(; - )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa; - )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb; - )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0; - )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0; - for(int k = N; k > 0; k = k - TK) { - acc = dot()" + usea + ", " + useb + R"(, acc); - pa = pa + TK)" + lda1 + R"(; - pb = pb + TK)" + ldb1 + R"(; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; - } - int rxc[TM] = (0 ... TM); - int ryc[TN] = (0 ... TN); - )" + c_ty_ + R"( c[TM, TN] = acc; - )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN; - *pc = c; - })"; - - return result; -} -void dot::triton_c_src(std::ostream &os) const { - if(op_ == FPROP || op_ == BPROP) - os << triton_c_src_ydx(); - else - os << triton_c_src_dw(); -} - - - -} -} -} diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp deleted file mode 100644 index 381691ff0..000000000 --- a/lib/dnn/conv.cpp +++ /dev/null @@ -1,720 +0,0 @@ -#include -#include "triton/dnn/conv.h" - -namespace triton{ -namespace dnn{ - -conv::conv(int B, int NC, - int D, int H, int W, - int T, int R, int S, int NF, - int stride_d, int stride_h, int stride_w, - int pad_d, int pad_h, int pad_w, - int upsample_d, int upsample_h, int upsample_w, - std::string a_ty, std::string b_ty, - type ty, bool bias) - : base("conv"), - NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF), - stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w), - pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w), - upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w), - a_ty_(a_ty), b_ty_(b_ty), - ty_(ty), bias_(bias) -{ - CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_; - CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_; - CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_; - // shapes - shapes_a_ = {NB_, NC_, AD_, AH_, AW_}; - shapes_b_ = {NC_, BD_, BH_, BW_, NF_}; - shapes_c_ = {NB_, NF_, CD_, CH_, CW_}; - // a layout - NCHW - a_outer_idx_ = 0; - a_inner_idx_ = 1; - a_pix_idx_ = 2; - // b layout - CRSK - b_inner_idx_ = 0; - b_pix_idx_ = 1; - b_outer_idx_ = 4; - // c layout - NKPQ - c_outer_0_idx_ = 0; - c_outer_1_idx_ = 1; - c_pix_idx = 2; - // swap a and c for bprop - if(ty_ == BPROP){ - std::swap(AD_, CD_); - std::swap(AH_, CH_); - std::swap(AW_, CW_); - shapes_a_.swap(shapes_c_); - std::swap(stride_d_, upsample_d_); - std::swap(stride_h_, upsample_h_); - std::swap(stride_w_, upsample_w_); - pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2; - pad_h_ = 
-    pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2;
-    std::swap(b_inner_idx_, b_outer_idx_);
-    std::swap(NC_, NF_);
-  }
-  // swap b and c for wgrad
-  if(ty_ == WGRAD){
-    shapes_b_.swap(shapes_c_);
-    std::swap(BD_, CD_);
-    std::swap(BH_, CH_);
-    std::swap(BW_, CW_);
-    std::swap(a_outer_idx_, a_inner_idx_);
-    std::swap(b_inner_idx_, c_outer_0_idx_);
-    std::swap(b_outer_idx_, c_outer_1_idx_);
-    std::swap(b_pix_idx_, c_pix_idx);
-  }
-  // leading dimensions
-  set_ld(shapes_a_, ld_a_);
-  set_ld(shapes_b_, ld_b_);
-  set_ld(shapes_c_, ld_c_);
-  // equivalent matmul
-  bool upsampled_b = (ty_ == BPROP) && (upsample_d_ > 1 || upsample_h_ > 1 || upsample_w_ > 1);
-  b_trans_ = ty_ != BPROP;
-  b_lut_ = ty_ == WGRAD || upsampled_b;
-  M_ = shapes_c_[c_outer_0_idx_]*shapes_c_[c_pix_idx]*shapes_c_[c_pix_idx+1]*shapes_c_[c_pix_idx+2];
-  N_ = shapes_c_[c_outer_1_idx_];
-  K_ = shapes_b_[b_inner_idx_]*BD_*BH_*BW_;
-  // look-up table info
-  if(ty_ == FPROP)
-    Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
-  else
-    Fs_ = K_;
-  TK_ = 8;
-  Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
-  build_a_deltas();
-  if(b_lut_)
-    build_b_deltas();
-  build_masks();
-  size_t cst_size = h_b_deltas_.size()*4;
-  is_b_deltas_cst_ = cst_size < 65536;
-  cst_size += h_a_deltas_.size()*4;
-  is_a_deltas_cst = cst_size < 65536;
-  cst_size += h_masks_.size()*4;
-  is_mask_cst_ = cst_size < 65536;
-  max_grid_0_ = 256;
-  max_grid_1_ = 256;
-}
-
-// comparison for maps
-std::vector<int64_t> conv::retune_params() const {
-  return {NB_, NC_, AD_, AH_, AW_,
-          NF_, BD_, BH_, BW_,
-          pad_d_, pad_h_, pad_w_,
-          stride_d_, stride_h_, stride_w_,
-          ty_, bias_};
-}
-
-// clone
-base* conv::clone() const {
-  return new conv(*this);
-}
-
-size_t conv::a_size()
-{ return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
-                         1, std::multiplies<int>()); }
-
-size_t conv::b_size()
-{ return std::accumulate(shapes_b_.begin(), shapes_b_.end(),
-                         1, std::multiplies<int>()); }
-
-size_t conv::c_size()
-{ return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
-                         1, std::multiplies<int>()); }
-
-std::vector<int32_t> conv::c_shapes()
-{ return shapes_c_; }
-
-
-std::tuple<int32_t, int32_t, int32_t, int32_t> conv::unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW) {
-  int32_t l, t, r, s;
-  if(b_trans_){
-    l = ltrs / (EBD*EBH*EBW);
-    int32_t trs = ltrs % (EBD*EBH*EBW);
-    int32_t tr = trs / EBW;
-    s = trs % EBW;
-    t = tr / EBH;
-    r = tr % EBH;
-  }
-  else{
-    int32_t rs = ltrs / NC_;
-    l = ltrs % NC_;
-    r = rs / EBW;
-    s = rs % EBW;
-  }
-  if(flip){
-    r = EBH - 1 - r;
-    s = EBW - 1 - s;
-  }
-  return std::make_tuple(l, t, r, s);
-}
-
-void conv::build_b_deltas(){
-  h_b_deltas_.resize(Luts_*upsample_d_*upsample_h_*upsample_w_);
-
-  size_t Ds0 = Luts_;
-  size_t Ds1 = upsample_w_;
-  size_t Ds2 = upsample_h_;
-  size_t Ds3 = upsample_d_;
-  for(size_t ud = 0; ud < Ds3; ++ud)
-  for(size_t uh = 0; uh < Ds2; ++uh)
-  for(size_t uw = 0; uw < Ds1; ++uw) {
-    int32_t* deltas_ptr = &h_b_deltas_[uw*Ds0 + uh*Ds0*Ds1 + ud*Ds0*Ds1*Ds2];
-    for(size_t i = 0; i < Luts_; ++i) {
-      int32_t EBD = 1;
-      int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
-      int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
-      if(EBD == 0 || EBH == 0 || EBW == 0)
-        continue;
-      int32_t c, t, r, s;
-      int32_t nextc, nextt, nextr, nexts;
-      std::tie(c, t, r, s) = unpack(i, false, EBD, EBH, EBW);
-      std::tie(nextc, nextt, nextr, nexts) = unpack(i + TK_, false, EBD, EBH, EBW);
-      int32_t cdiff = nextc - c;
-      int32_t tdiff = (nextt - t)*upsample_d_;
-      int32_t rdiff = (nextr - r)*upsample_h_;
-      int32_t sdiff = (nexts - s)*upsample_w_;
-      deltas_ptr[i] = cdiff*ld_b_[b_inner_idx_] + tdiff*ld_b_[b_pix_idx_] + rdiff*ld_b_[b_pix_idx_ + 1] + sdiff*ld_b_[b_pix_idx_ + 2];
-    }
-  }
-}
-
-void conv::build_a_deltas(){
-  h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
-  for(size_t i = 0; i < Luts_; ++i)
-    h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
-  size_t Ds0 = Luts_;
-  size_t Ds1 = upsample_w_;
-  size_t Ds2 = upsample_h_;
-  size_t Ds3 = upsample_d_;
-  for(size_t ud = 0; ud < Ds3; ++ud)
-  for(size_t uh = 0; uh < Ds2; ++uh)
-  for(size_t uw = 0; uw < Ds1; ++uw) {
-    int32_t* deltas_ptr = &h_a_deltas_[Luts_ + uw*Ds0 + uh*Ds0*Ds1 + ud*Ds0*Ds1*Ds2];
-    // cumulative increments
-    for(size_t i = 0; i < Ds0; ++i) {
-      int32_t EBD = 1;
-      int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
-      int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
-      if(EBD == 0 || EBH == 0 || EBW == 0)
-        continue;
-      // unpack
-      int32_t ctrs = i;
-      int32_t c, t, r, s;
-      std::tie(c, t, r, s) = unpack(ctrs, !b_trans_, EBD, EBH, EBW);
-      // next indices
-      int32_t nextctrs = ctrs + TK_;
-      int32_t nextc, nextt, nextr, nexts;
-      std::tie(nextc, nextt, nextr, nexts) = unpack(nextctrs, !b_trans_, EBD, EBH, EBW);
-      // diffs
-      int32_t cdiff = nextc - c;
-      int32_t tdiff = nextt - t;
-      int32_t rdiff = nextr - r;
-      int32_t sdiff = nexts - s;
-      if(ty_ == WGRAD){
-        tdiff = tdiff * stride_d_;
-        rdiff = rdiff * stride_h_;
-        sdiff = sdiff * stride_w_;
-      }
-      // delta pointers
-      deltas_ptr[i] = cdiff*ld_a_[a_inner_idx_] + tdiff*ld_a_[a_pix_idx_] + rdiff*ld_a_[a_pix_idx_ + 1] + sdiff*ld_a_[a_pix_idx_ + 2];
-    }
-  }
-}
-
-void conv::build_masks(){
-  h_masks_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*(2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
-
-  size_t Ms0 = Luts_;
-  size_t Ms1 = 2*pad_w_ + 1;
-  size_t Ms2 = 2*pad_h_ + 1;
-  size_t Ms3 = 2*pad_d_ + 1;
-  size_t Ms4 = upsample_w_;
-  size_t Ms5 = upsample_h_;
-  size_t Ms6 = upsample_d_;
-  for(size_t ud = 0; ud < Ms6; ++ud)
-  for(size_t uh = 0; uh < Ms5; ++uh)
-  for(size_t uw = 0; uw < Ms4; ++uw)
-  for(size_t pd = 0; pd < Ms3; ++pd)
-  for(size_t ph = 0; ph < Ms2; ++ph)
-  for(size_t pw = 0; pw < Ms1; ++pw){
-    int32_t* masks_ptr = &h_masks_[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2 + uw*Ms0*Ms1*Ms2*Ms3 + uh*Ms0*Ms1*Ms2*Ms3*Ms4 + ud*Ms0*Ms1*Ms2*Ms3*Ms4*Ms5];
-    for(size_t i = 0; i < Ms0; ++i){
-      int32_t l, t, r, s;
-      int32_t mask = 0x0;
-      for(size_t j = 0; j < TK_; ++j){
-        int32_t EBD = 1;
-        int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
-        int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
-        if(EBD == 0 || EBH == 0 || EBW == 0)
-          continue;
-        std::tie(l, t, r, s) = unpack(i + j, !b_trans_, EBD, EBH, EBW);
-        bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (EBD + pad_d_);
-        bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (EBH + pad_h_);
-        bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (EBW + pad_w_);
-        mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
-      }
-      masks_ptr[i] = mask;
-    }
-  }
-  for(size_t i = 0; i < Luts_; ++i)
-    h_masks_[i] = 0x0;
-}
-
-std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN){
-  return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
-}
-
-size_t conv::num_flops() const{
-  return 2.*M_*N_*K_;
-}
-
-void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module, triton::runtime::launch_information info) {
-  auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
-    if(host.empty())
-      return nullptr;
-    size_t nbytes = host.size()*4;
// get buffer - triton::driver::buffer* buffer; - if(is_cst) - buffer = module->symbol(name); - else - buffer = triton::driver::buffer::create(stream->context(), nbytes); - // copy - stream->write(buffer, false, 0, nbytes, host.data()); - return buffer; - }; - if(d_a_deltas_ == nullptr) - d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_); - if(d_b_deltas_ == nullptr) - d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_); - if(d_masks_ == nullptr) - d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_); - if(d_locks_ == nullptr){ - d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2); - ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2); - } -} - -void conv::set_arg(driver::kernel *kernel, - driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias) -{ - kernel->setArg(0, a); - kernel->setArg(1, b); - kernel->setArg(2, c); - kernel->setArg(3, bias); - kernel->setArg(4, M_); - kernel->setArg(5, N_); - kernel->setArg(6, K_); - kernel->setArg(7, AH_); - kernel->setArg(8, AW_); - kernel->setArg(9, BH_); - kernel->setArg(10, BW_); - kernel->setArg(11, CH_); - kernel->setArg(12, CW_); - kernel->setArg(13, NC_); - // A arguments - kernel->setArg(14, ld_a_[a_outer_idx_]); - kernel->setArg(15, ld_a_[a_inner_idx_]); - kernel->setArg(16, ld_a_[2]); - kernel->setArg(17, ld_a_[3]); - kernel->setArg(18, ld_a_[4]); - // B arguments - kernel->setArg(19, ld_b_[b_inner_idx_]); - kernel->setArg(20, ld_b_[b_pix_idx_]); - kernel->setArg(21, ld_b_[b_pix_idx_+1]); - kernel->setArg(22, ld_b_[b_pix_idx_+2]); - kernel->setArg(23, ld_b_[b_outer_idx_]); - // C arguments - kernel->setArg(24, ld_c_[c_outer_0_idx_]); - kernel->setArg(25, ld_c_[c_outer_1_idx_]); - kernel->setArg(26, ld_c_[c_pix_idx]); - kernel->setArg(27, ld_c_[c_pix_idx+1]); - kernel->setArg(28, ld_c_[c_pix_idx+2]); - // pad - kernel->setArg(29, pad_h_); - kernel->setArg(30, pad_w_); - // stride - kernel->setArg(31, stride_h_); - kernel->setArg(32, stride_w_); - // dilate - kernel->setArg(33, upsample_h_); - kernel->setArg(34, upsample_w_); - kernel->setArg(35, (int32_t)0); - kernel->setArg(36, (int32_t)0); - kernel->setArg(37, pad_h_); - kernel->setArg(38, pad_w_); - kernel->setArg(39, (int32_t)0); - kernel->setArg(40, (int32_t)0); - kernel->setArg(41, d_locks_); - kernel->setArg(42, max_grid_0_); - kernel->setArg(43, max_grid_1_); - size_t idx = 44; - if(!is_a_deltas_cst) - kernel->setArg(idx++, d_a_deltas_); - if(!is_b_deltas_cst_) - kernel->setArg(idx++, d_b_deltas_); - if(!is_mask_cst_) - kernel->setArg(idx++, d_masks_); -} - -void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel, - std::vector<driver::buffer*> args, - runtime::launch_information info) { - driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3]; - unsigned TM = info.globals["TM"], TN = info.globals["TN"]; - unsigned GZ = 1; - set_arg(kernel, a, b, c, bias); - std::array<size_t, 3> grid = {1}; - grid[0] = (M_ + TM - 1)/TM; - grid[1] = (N_ + TN - 1)/TN; - grid[2] = GZ; - grid[0] /= upsample_h_*upsample_w_; - kernel->setArg(11, CH_/upsample_h_); - kernel->setArg(12, CW_/upsample_w_); - - // initialize to zero if necessary - bool init_zero = false; - for(int32_t off_uh = 0; off_uh < upsample_h_; off_uh++) - for(int32_t off_uw = 0; off_uw < upsample_w_; off_uw++) { - int32_t EBD = 1; - int32_t EBH = ((upsample_h_ - off_uh - 1) + BH_) / upsample_h_; - int32_t EBW = ((upsample_w_ - off_uw - 1) + BW_) / upsample_w_; - if(EBD == 0 || EBH == 0 || EBW == 0) - init_zero = true; - }
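// NB: each (off_uh, off_uw) launch below covers only the output pixels of its
// own upsampling phase, and phases whose filter sub-tile is empty are skipped
// entirely; C is therefore cleared up front so that skipped pixels read as zero.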
if(init_zero) - ((driver::cu_buffer*)c)->set_zero(stream, c_size()*4); - - for(int32_t off_uh = 0; off_uh < upsample_h_; off_uh++) - for(int32_t off_uw = 0; off_uw < upsample_w_; off_uw++) { - int32_t EBD = 1; - int32_t EBH = ((upsample_h_ - off_uh - 1) + BH_) / upsample_h_; - int32_t EBW = ((upsample_w_ - off_uw - 1) + BW_) / upsample_w_; - if(EBD == 0 || EBH == 0 || EBW == 0) - continue; - int32_t K = shapes_b_[b_inner_idx_]*EBD*EBH*EBW; - kernel->setArg(6, K); - kernel->setArg(9, EBH); - kernel->setArg(10, EBW); - kernel->setArg(29, pad_h_); - kernel->setArg(30, pad_w_); - kernel->setArg(35, off_uh); - kernel->setArg(36, off_uw); - kernel->setArg(37, (pad_h_ + (1 - upsample_h_)*off_uh)/upsample_h_); - kernel->setArg(38, (pad_w_ + (1 - upsample_w_)*off_uw)/upsample_w_); - kernel->setArg(39, (off_uh + pad_h_) % upsample_h_); - kernel->setArg(40, (off_uw + pad_w_) % upsample_w_); - stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); - } -} - -std::vector<unsigned> conv::default_params() { - if(b_lut_){ - if(!b_trans_) - return {16, 2, 32, 16, 16, 8, 8, 2, 2, 4, 2, 8, 4, 2, 1}; - else - return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8, 1}; - } - else if(ty_ == FPROP) - return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4, 1}; - else - return {16, 2, 64, 16, 16, 16, 4, 2, 2, 4, 2, 8, 4, 2, 1}; -} - - -/* CPU reference implementation */ - -template<class IN_DTYPE, class OUT_DTYPE> -void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) -{ - IN_DTYPE acc; - for(int32_t n = 0; n < shapes_c_[0]; ++n) - for(int32_t cf = 0; cf < shapes_c_[1] ; ++cf) - for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd) - for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch) - for(int32_t cw = 0; cw < shapes_c_[4]; ++cw) - { - acc = 0; - int32_t d = cd*stride_d_ - pad_d_; - int32_t h = ch*stride_h_ - pad_h_; - int32_t w = cw*stride_w_ - pad_w_; - for(int32_t ac = 0; ac < shapes_a_[1]; ++ac) - for(int32_t bd = 0; bd < shapes_b_[1]; ++bd) - for(int32_t bh = 0; bh < shapes_b_[2]; ++bh) - for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){ - int32_t ad = d + bd; - int32_t ah = h + bh; - int32_t aw = w + bw; - bool in_bounds = (ad >= 0 && ad < shapes_a_[2] && - ah >= 0 && ah < shapes_a_[3] && - aw >= 0 && aw < shapes_a_[4]); - IN_DTYPE a = 0; - if(in_bounds) - a = A[n*ld_a_[0] + ac*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]]; - IN_DTYPE b; - if(b_trans_) - b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]]; - else{ - int32_t bdd = shapes_b_[1] - 1 - bd; - int32_t bhh = shapes_b_[2] - 1 - bh; - int32_t bww = shapes_b_[3] - 1 - bw; - b = B[cf*ld_b_[0] + bdd*ld_b_[1] + bhh*ld_b_[2] + bww*ld_b_[3] + ac*ld_b_[4]]; - } - acc = std::fma(a, b, acc); - } - C[n*ld_c_[0] + cf*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc; - } -} - -template<class IN_DTYPE, class OUT_DTYPE> -void conv::cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) -{ - IN_DTYPE acc; - for(int32_t c = 0 ; c < shapes_c_[0]; ++c) - for(int32_t cd = 0; cd < shapes_c_[1]; ++cd) - for(int32_t ch = 0; ch < shapes_c_[2]; ++ch) - for(int32_t cw = 0; cw < shapes_c_[3]; ++cw) - for(int32_t k = 0 ; k < shapes_c_[4]; ++k) - { - acc = 0; - int32_t d = cd*stride_d_ - pad_d_; - int32_t h = ch*stride_h_ - pad_h_; - int32_t w = cw*stride_w_ - pad_w_; - for(int32_t n = 0; n < shapes_b_[0]; ++n) - for(int32_t bd = 0; bd < shapes_b_[2]; ++bd) - for(int32_t bh = 0; bh < shapes_b_[3]; ++bh) - for(int32_t bw = 0; bw < shapes_b_[4]; ++bw){ - int32_t ad = d + bd; - int32_t ah = h + bh; - int32_t aw = w + bw; - bool in_bounds = (ad >= 0 && ad < shapes_a_[2] && - ah >= 0 && ah < shapes_a_[3] && - aw >= 0 && aw <
shapes_a_[4]); - IN_DTYPE a = 0; - if(in_bounds) - a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]]; - IN_DTYPE b = B[n*ld_b_[0] + k*ld_b_[1] + bd*ld_b_[2] + bh*ld_b_[3] + bw*ld_b_[4]]; - acc = std::fma(a, b, acc); - } - C[c*ld_c_[0] + cd*ld_c_[1] + ch*ld_c_[2] + cw*ld_c_[3] + k*ld_c_[4]] = acc; - } -} - -template<class IN_DTYPE, class OUT_DTYPE> -void conv::cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) -{ - if(ty_ == FPROP || ty_ == BPROP) - cpu_xprop(C, A, B); - else - cpu_wgrad(C, A, B); -} - -/* Triton-C source code */ - -void conv::triton_c_src(std::ostream &os) const { - std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]"; - std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]"; - std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]"; - std::string ldb0 = b_trans_ ? "*ldb_s" : ""; - std::string useb = b_trans_ ? "trans(b)" : "b"; - std::string flipr = b_trans_ ? "" : "BH - 1 -"; - std::string flips = b_trans_ ? "" : "BW - 1 -"; - std::string upar = ty_ == WGRAD ? "stride_h * ": ""; - std::string upas = ty_ == WGRAD ? "stride_w * ": ""; - std::string upah = ty_ == WGRAD ? "": "*stride_h"; - std::string upaw = ty_ == WGRAD ? "": "*stride_w"; - std::vector<std::string> crs = {"c", "r", "s"}; - std::vector<std::string> rsc = {"r", "s", "c"}; - std::vector<std::string> ax = b_trans_ ? crs : rsc; - std::vector<std::string> redax; - if(b_trans_) - redax = {"NC", "BH", "BW"}; - else - redax = {"BH", "BW", "NC"}; - std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0; - std::string inc_pdb = b_trans_ ? "incd" : "TK"; - std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : ""; - std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : ""; - std::string masks_mem = is_mask_cst_? "__constant__" : ""; - - os << - R"( -const tunable int TM = {16, 32, 64}; -const tunable int TN = {16, 32, 64}; -const tunable int TK = {)" << TK_ << R"(}; -const tunable int GZ = {1}; -)"; -if(is_a_deltas_cst) - os << "__constant__ int* delta = alloc_const int[" + std::to_string(h_a_deltas_.size()) + "];\n"; -if(b_lut_ && is_b_deltas_cst_) - os << "__constant__ int* b_delta = alloc_const int[" + std::to_string(h_b_deltas_.size()) + "];\n"; -if(is_mask_cst_) - os << "__constant__ int* masks = alloc_const int[" + std::to_string(h_masks_.size()) + "];\n"; -os << R"( - - void conv(read_only restrict )" << a_ty_ << R"( *a, - read_only restrict )" << b_ty_ << R"( *b, - float *c, - float *bias, - int M, int N, int K, - int AH, int AW, - int BH, int BW, - int CH, int CW, - int NC, - int lda_n, int lda_c, int lda_d, int lda_h, int lda_w, - int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k, - int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q, - int pad_h, int pad_w, - int stride_h, int stride_w, - int upsample_h, int upsample_w, - int off_uh, int off_uw, - int off_uah, int off_uaw, - int off_uch, int off_ucw, - int *locks, int grid0, int grid1)"; -if(!is_a_deltas_cst) - os << ", int* delta"; -if(b_lut_ && !is_b_deltas_cst_) - os << ", int* b_delta"; -if(!is_mask_cst_) - os << ", int* masks"; - os << R"(){ - int rxa[TM] = get_global_range[TM](0); - int rb0[TN] = get_global_range[TN](1); - int rz = get_global_range[1](2); - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ...
TK; - float C[TM, TN] = 0; - int ldlut = )" + std::to_string(Luts_) + R"(; - int div = K / GZ; - int rem = K % GZ; - K = select(rz < rem, div, div + rem); - int offk = rz*div; - rka = rka + offk; - rkb = rkb + offk; - int rabh[TM] = rxa / CW; - int raw[TM] = rxa % CW; - int rab[TM] = rabh / CH; - int rah[TM] = rabh % CH; - rah = rah)" + upaw + R"( - off_uah; - raw = raw)" + upah + R"( - off_uaw; - int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; - int ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(; - int ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(; - int ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(; - int ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(; - rar = )" + flipr + R"( rar; - ras = )" + flips + R"( ras; - rar = )" + upar + R"( rar; - ras = )" + upas + R"( ras; - int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; - )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)"; -if(b_lut_){ - os << R"( - int rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(; - int rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(; - int rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(; - int rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(; - rbr = rbr*upsample_h + off_uh; - rbs = rbs*upsample_w + off_uw; - int offdb[TK] = rkb % ldlut; - int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s; - )" + b_delta_mem + R"( int* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w; - int db[TK] = *pdb;)"; -} -else{ -os << R"( - int rb1[TK] = rkb)" + ldb0 + ";"; -} -os << R"( - )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k; - int offda[TK] = rka % ldlut; - )" + a_delta_mem + R"( int* pincd[TK] = delta + offda; - )" + a_delta_mem + R"( int* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w; - int da[TK] = *pda; - int incd[TK] = *pincd; - int maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); - int maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); - int offma = offk % ldlut; - )" + masks_mem + R"( int* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w; - )" + a_delta_mem + R"( int* pincm[TM] = delta + offma; - int incm[TM] = *pincm; - int maska0[TM] = *pm; - int maska1[TK] = 1 << (0 ... TK); - bool checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0; - bool checkb0[TN] = rb0 < N; - bool checkb)" + BS + " = checkb0" + bcb0 + R"(; - )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0; - )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0; - int rkamin[TK] = rka - offk + TK; - for(int k = K; k > 0; k = k - TK){ - C = dot(a, )" + useb + R"(, C); - pa = pa + da[newaxis, :]; - pb = pb + )" + inc_pb + R"(; - pda = pda + incd;)"; -if(b_lut_){ - os << R"( - pdb = pdb + )" + inc_pdb + R"(; - db = *pdb;)"; -} - os << R"( - pincd = pincd + incd; - da = *pda; - incd = *pincd; - pm = pm + incm; - pincm = pincm + incm; - incm = *pincm; - bool checka1[TK] = (rkamin < k); - maska0 = *pm; - checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0; - checka = checka && checka1[newaxis,:]; - a = checka ? 
*pa : 0; - checkb = checkb && (k > TK); - @checkb b = *pb; - } - int rxc[TM] = get_global_range[TM](0); - int rc1[TN] = get_global_range[TN](1); - int rcn[TM] = rxc / (CH*CW); - int rcpq[TM] = rxc % (CH*CW); - int rcp[TM] = rcpq / CW; - int rcq[TM] = rcpq % CW; - rcp = rcp * upsample_h + off_uch; - rcq = rcq * upsample_w + off_ucw; - bool checkc1[TN] = rc1 < N; - int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q; - float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; - bool checkc0[TM] = rxc < M; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int *plock = locks + ridx + ridy*grid0; - while(__atomic_cas(plock, 0, 1) == 1); - int *pcount = plock + grid0*grid1; - int count = *pcount; - int countp1 = select(count == GZ - 1, 0, count + 1); - if(count == 0) {)"; - if(bias_ && ty_==FPROP){ - os << R"( - float* pbias[TN] = bias + rc1; - float bias[TN] = checkc1 ? *pbias : 0; - C = C + bias[newaxis, :];)"; - } - os << R"( - @checkc *pc = C; - *pcount = countp1; - } - else { - @checkc *pc = C + *pc; - *pcount = countp1; - } - *plock = 0; -})"; -} - -template void conv::cpu_ref<float, float>(float*, float*, float*); -template void conv::cpu_xprop<float, float>(float*, float*, float*); -template void conv::cpu_wgrad<float, float>(float*, float*, float*); - -} -} diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp deleted file mode 100644 index f3d35a2f0..000000000 --- a/lib/dnn/dot.cpp +++ /dev/null @@ -1,162 +0,0 @@ -#include "triton/driver/stream.h" -#include "triton/driver/kernel.h" -#include "triton/dnn/dot.h" -#include "triton/dnn/heuristics.h" -#include - -namespace triton{ -namespace dnn{ - -dot::dot(int M, int N, int K, - bool AT, bool BT, - std::string a_ty, std::string b_ty, std::string c_ty, - unsigned align_lda, unsigned align_ldb, unsigned align_ldc) - : base("matmul"), - M_(M), N_(N), K_(K), AT_(AT), BT_(BT), - a_ty_(a_ty), b_ty_(b_ty), c_ty_(c_ty), - align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc), - locks_(nullptr) { - -} - -size_t dot::num_flops() const { - return 2.*M_*N_*K_; -} - -// retune parameters -std::vector<int64_t> dot::retune_params() const { - return {M_, N_, K_, AT_, BT_, - (int)align_lda_, (int)align_ldb_}; -} - -// clone -base* dot::clone() const { - return new dot(*this); -} - -void dot::init_impl(driver::stream* stream, driver::cu_module *, runtime::launch_information) { - std::vector<int32_t> hlocks(2048, 0); - if(locks_ == nullptr) - locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4); - stream->write(locks_, false, 0, hlocks); -} - -void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, - std::vector<driver::buffer*> args, - runtime::launch_information info) { - driver::buffer *a = args[0], *b = args[1], *c = args[2]; - unsigned TM = info.globals.at("TM"); - unsigned TN = info.globals.at("TN"); - unsigned TK = info.globals.at("TK"); - unsigned grid_0 = (M_ + TM - 1)/TM; - unsigned grid_1 = (N_ + TN - 1)/TN; - unsigned grid_2 = 1; - int32_t lda = AT_ ? K_ : M_; - int32_t ldb = BT_ ? N_ : K_; - int32_t ldc = M_;
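// NB: operands are assumed column-major, so a transposed input is K-major:
// lda = AT ? K : M, ldb = BT ? N : K, and C is written with ldc = M.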
- std::array<size_t, 3> grid = {grid_0, grid_1, grid_2}; - kernel->setArg(0, a); - kernel->setArg(1, b); - kernel->setArg(2, c); - kernel->setArg(3, M_); - kernel->setArg(4, N_); - kernel->setArg(5, K_); - kernel->setArg(6, lda); - kernel->setArg(7, ldb); - kernel->setArg(8, ldc); - kernel->setArg(9, TK); - kernel->setArg(10, locks_); - kernel->setArg(11, grid_0); - kernel->setArg(12, grid_1); - stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); -} - -void dot::triton_c_src(std::ostream &os) const { - std::string ZS = "1"; - std::string AS0 = "TM", AS1 = "TK"; - std::string BS0 = "TK", BS1 = "TN"; - std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS; - std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN"; - std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; - std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; - std::string lda0 = "*lda", lda1 = ""; - std::string ldb0 = "", ldb1 = "*ldb"; - std::string usea = AT_ ? "trans(a)" : "a"; - std::string useb = BT_ ? "trans(b)" : "b"; - if(AT_){ - std::swap(AS0, AS1); - std::swap(XAS0, XAS1); - std::swap(XAS1, XAS2); - std::swap(bca0, bca1); - std::swap(lda0, lda1); - } - if(BT_){ - std::swap(BS0, BS1); - std::swap(XBS1, XBS2); - std::swap(XBS0, XBS1); - std::swap(bcb0, bcb1); - std::swap(ldb0, ldb1); - } - std::string AS = AS0 + ", " + AS1; - std::string BS = BS0 + ", " + BS1; -// std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2; -// std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2; - std::string XCS = "TM, TN"; - std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")"; - std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; - std::string res = -R"( -const tunable int TM = {128}; -const tunable int TN = {128}; -const tunable int TK = {32}; - -void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, - restrict read_only align(16) )" + b_ty_ + R"( *B, - restrict read_only align(16) )" + c_ty_ + R"( *C, - int M, int N, int K, - )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"( int ldb, int ldc, - int bound, int *locks, int grid0, int grid1) { - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rxa[TM] = ridx * TM + (0 ... TM); - int ryb[TN] = ridy * TN + (0 ... TN); - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - float xc[)" + XCS + R"(] = 0; - )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; - )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; - )" + a_ty_ + R"( a[)" + AS + R"(] = *pa; - )" + b_ty_ + R"( b[)" + BS + R"(] = *pb; - for(int k = K; k > 0; k = k - TK){ - xc = dot()" + usea + ", " + useb + R"(, xc); - pa = pa + TK)" + lda0 + R"(; - pb = pb + TK)" + ldb0 + R"(; - a = *pa; - b = *pb; - } - int rxc[TM] = ridx * TM + (0 ... TM); - int ryc[TN] = ridy * TN + (0 ... TN); - )" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - )" + c_ty_ + R"( c[TM, TN] = xc; - bool checkc0[TM] = rxc < M; - bool checkc1[TN] = ryc < N; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - @checkc *pc = c; -} -)"; - - os << res; -}
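For orientation, a minimal driver of this generator might look as follows; this is a hypothetical sketch (sizes and alignments arbitrary), not code from the deleted tree:

#include "triton/dnn/dot.h"
#include <iostream>
#include <sstream>

// print the Triton-C matmul source generated for a 1024^3 FP32 problem
// with A non-transposed and B transposed
int main() {
  triton::dnn::dot op(1024, 1024, 1024, /*AT=*/false, /*BT=*/true,
                      "float", "float", "float",
                      /*align_lda=*/8, /*align_ldb=*/8, /*align_ldc=*/8);
  std::ostringstream oss;
  op.triton_c_src(oss);
  std::cout << oss.str() << std::endl;
  return 0;
}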
- -// small search space for partial auto-tuning -std::vector<params_t> dot::search_space() const { - return dot_search_space(AT_, BT_); -} - -// simple parameter heuristics -params_t dot::heuristics() const { - return dot_heuristics(AT_, BT_, M_, N_, K_); -} - -} -} diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp deleted file mode 100644 index 93ae57cd4..000000000 --- a/lib/dnn/shift.cpp +++ /dev/null @@ -1,538 +0,0 @@ -#include -#include "triton/dnn/shift.h" -#include "triton/dnn/heuristics.h" -#include "triton/tools/bench.hpp" - -namespace triton{ -namespace dnn{ - - -shift::shift(int B, int C, - int D, int H, int W, - int T, int R, int S, - int F, - int stride_h, int stride_w, - const int32_t *shift_h, const int32_t *shift_w, - std::string a_ty, std::string b_ty, - op_t ty, bool bias, - layout_t layout) - : base("shift"), - B_(B), C_(C), - AD_(D), AH_(H), AW_(W), - BD_(T), BH_(R), BW_(S), - F_(F), - stride_d_(1), stride_h_(stride_h), stride_w_(stride_w), - shift_h_(shift_h), shift_w_(shift_w), - a_ty_(a_ty), b_ty_(b_ty), c_ty_(b_ty), - op_(ty), bias_(bias), - layout_(layout){ -// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl; - // max number of channels - TK_ = (ty == FPROP && a_ty_ == "float") ? 8 : 32; - MAX_C_ = 8192 + TK_; - // activation sizes - CD_ = AD_ / stride_d_; - CH_ = AH_ / stride_h_; - CW_ = AW_ / stride_w_; - // A memory strides: [C, H, W, B] - switch(layout_){ - case CHWN: { - lda_n_ = 1; - lda_w_ = B_; - lda_h_ = B_*AW_; - lda_c_ = B_*AW_*AH_; - break; - } - case NCHW: { - lda_w_ = 1; - lda_h_ = AW_; - lda_c_ = AW_*AH_; - lda_n_ = AW_*AH_*C_; - break; - } - default: - throw std::runtime_error("unsupported input layout"); - } - // Shift edge - shift_edge_h_ = (AH_ == stride_h_ && stride_h_ > 1); - shift_edge_w_ = (AW_ == stride_w_ && stride_w_ > 1); - // B memory strides: [C, F] - ldb_n_ = 1; - ldb_h_ = 1; - ldb_w_ = 1; - ldb_c_ = F_; - // C memory strides: [F, H, W, B] - switch(layout_){ - case CHWN: { - ldc_n_ = 1; - ldc_w_ = B_; - ldc_h_ = B_*CW_; - ldc_f_ = B_*CW_*CH_; - break; - } - case NCHW: { - ldc_w_ = 1; - ldc_h_ = CW_; - ldc_f_ = CW_*CH_; - ldc_n_ = CW_*CH_*F_; - break; - } - default: - throw std::runtime_error("unsupported input layout"); - } - IAD_ = AD_ - 2*(BD_/2); - IAH_ = AH_ - 2*(BH_/2); - IAW_ = AW_ - 2*(BW_/2); - ICD_ = IAD_ / stride_d_; - ICH_ = IAH_ / stride_h_; - ICW_ = IAW_ / stride_w_; - - // Equivalent matmul - M_ = B_*ICH_*ICW_; - N_ = F_; - K_ = C_; - // transpose - AT_ = false; - BT_ = true; - // C shapes - if(layout_ == CHWN) - shapes_c_ = {F, CH_, CW_, B}; - if(layout_ == NCHW) - shapes_c_ = {B, F, CH_, CW_}; - // Weight gradient - if(op_ == WGRAD){ - // b <-> c - // b <-> a - std::swap(ldb_n_, ldc_n_); - std::swap(ldb_w_, ldc_w_); - std::swap(ldb_h_, ldc_h_); - std::swap(ldb_c_, ldc_f_); - std::swap(lda_n_, ldb_n_); - std::swap(lda_w_, ldb_w_); - std::swap(lda_h_, ldb_h_); - std::swap(lda_c_, ldb_c_); - std::swap(M_, K_); - std::swap(M_, N_); - AT_ = true; - BT_ = false; - shapes_c_ = {C, F}; - } - // Input gradient - if(op_ == BPROP){ - // a <-> c - std::swap(lda_n_, ldc_n_); - std::swap(lda_w_, ldc_w_); - std::swap(lda_h_, ldc_h_); - std::swap(lda_c_, ldc_f_); - std::swap(K_, N_); - AT_ = false; - BT_ = false; - if(layout_ == CHWN) - shapes_c_ = {C, AH_, AW_, B}; - if(layout_ == NCHW) - shapes_c_ = {B, C, AH_, AW_}; - }
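// NB: BPROP (and WGRAD above) reuse the forward matmul skeleton by permuting
// the stride/shape bookkeeping: swapping the a/b/c leading dimensions together
// with the M/N/K roles lets the same generated kernel produce the gradients.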
- // locks - max_locks_ = (op_ == WGRAD) ? 8192 : 0; - locks_ = nullptr; -} - -base* shift::clone() const { - return new shift(*this); -} - -void shift::build_delta_a() { - h_delta_a.resize(MAX_C_); - auto shift_h = [&](int c) { return shift_edge_h_ ? (c / AH_) % AH_ : shift_h_[c]; }; - auto shift_w = [&](int c) { return shift_edge_w_ ? c % AW_ : shift_w_[c]; }; - if(op_ == FPROP){ - // compute offset - auto offset = [&](unsigned c) { - return c*lda_c_ + shift_h(c)*lda_h_ + shift_w(c)*lda_w_; - }; - // populate look-up table - for(unsigned c = 0; c < TK_; c++) - h_delta_a[c] = offset(c); - for(unsigned c = 0; c < C_; c++) - h_delta_a[TK_ + c] = offset(c + TK_) - offset(c); - } - if(op_ == BPROP){ - for(unsigned c = 0; c < C_; c++){ - h_delta_a[c] = shift_h(c)*ldc_h_ + shift_w(c)*ldc_w_; - } - } - if(op_ == WGRAD){ - for(unsigned c = 0; c < C_; c++) - h_delta_a[c] = shift_h(c)*ldb_h_ + shift_w(c)*ldb_w_; - } -} - -size_t shift::c_size() { - return std::accumulate(shapes_c_.begin(), shapes_c_.end(), - 1, std::multiplies<int>()); -} - -std::vector<int32_t> shift::c_shapes(){ - return shapes_c_; -} - -size_t shift::num_flops() const { - return 2.*M_*N_*K_; -} - -bool shift::AT() const -{ return AT_; } - -bool shift::BT() const -{ return BT_; } - -size_t shift::M() const -{ return M_; } - -size_t shift::N() const -{ return N_; } - -size_t shift::K() const -{ return K_; } - -size_t shift::lda() const -{ return AT_ ? K_ : M_; } - -size_t shift::ldb() const -{ return BT_ ? N_ : K_; } - -size_t shift::ldc() const -{ return M_; } - -std::vector<int64_t> shift::retune_params() const { - return {B_, C_, F_, - AD_, AH_, AW_, - BD_, BH_, BW_, - CD_, CH_, CW_, - (int64_t)shift_h_, (int64_t)shift_w_, - stride_h_, stride_w_, - layout_, op_, - bias_}; -} - -void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) { - build_delta_a(); - triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a"); - stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data()); - // locks - if(locks_ == nullptr && max_locks_ > 0){ - std::vector<int32_t> hlocks(2*max_locks_, 0); - locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*4); - stream->write(locks_, false, 0, hlocks); - } -} - -void shift::deinit_impl() { - if(locks_ != nullptr){ - delete locks_; - locks_ = nullptr; - } -}
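For FPROP the table above is laid out in two banks: the first TK_ entries hold the absolute offsets of the initial pointers, and the next C_ entries hold the increment from channel c to channel c + TK_, which the kernel adds after every reduction step. A standalone sketch of that layout (hypothetical helper; off[c] stands for the per-channel shifted offset computed above):

#include <cstdint>
#include <vector>

// sketch: two-banked delta table — bank 0: absolute start offsets,
// bank 1: per-channel increments to the next K-step (needs off.size() >= C + TK)
std::vector<int32_t> build_delta_table(int C, int TK,
                                       const std::vector<int32_t>& off) {
  std::vector<int32_t> lut(TK + C);
  for(int c = 0; c < TK; c++)
    lut[c] = off[c];                      // initial pointer offsets
  for(int c = 0; c < C; c++)
    lut[TK + c] = off[c + TK] - off[c];   // delta applied after each step
  return lut;
}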
-void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, - std::vector<driver::buffer*> args, - runtime::launch_information info) { - unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN"); - unsigned grid_0 = (M_ + TM - 1)/TM; - unsigned grid_1 = (N_ + TN - 1)/TN; - unsigned num_locks = grid_0 * grid_1; - unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1; - std::array<size_t, 3> grid = {grid_0, grid_1, grid_2}; - driver::buffer *a = args[0], *b = args[1], *c = args[2]; -// std::cout << op_ << " " << M_ << " " << N_ << " " << K_ << std::endl; - kernel->setArg(0, a); - kernel->setArg(1, b); - kernel->setArg(2, c); - kernel->setArg(3, M_); - kernel->setArg(4, N_); - kernel->setArg(5, K_); - kernel->setArg(6, stride_h_); - kernel->setArg(7, stride_w_); - kernel->setArg(8, lda_n_); - kernel->setArg(9, lda_w_); - kernel->setArg(10, lda_h_); - kernel->setArg(11, lda_c_); - kernel->setArg(12, ldb_n_); - kernel->setArg(13, ldb_w_); - kernel->setArg(14, ldb_h_); - kernel->setArg(15, ldb_c_); - kernel->setArg(16, ldc_n_); - kernel->setArg(17, ldc_w_); - kernel->setArg(18, ldc_h_); - kernel->setArg(19, ldc_f_); - kernel->setArg(20, B_); - kernel->setArg(21, IAH_); - kernel->setArg(22, IAW_); - kernel->setArg(23, BH_); - kernel->setArg(24, BW_); - kernel->setArg(25, ICH_); - kernel->setArg(26, ICW_); - kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_); - kernel->setArg(28, (int32_t)grid[0]); - kernel->setArg(29, (int32_t)grid[1]); - kernel->setArg(30, (int32_t)grid[2]); - if(locks_) - ((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4); - stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); -}
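Both the conv and shift kernels retire their split-K partial sums through a per-tile lock/counter pair in the locks buffer. In host-side C++ terms, the protocol the emitted Triton-C implements is roughly the following (illustrative sketch only; names are hypothetical and the real synchronization happens on the GPU via __atomic_cas):

#include <atomic>

// Illustrative sketch of the accumulation protocol: the first block to
// acquire the lock stores its tile, later ones add theirs, and the
// counter wraps to zero once all GZ slices have contributed.
void accumulate_tile(std::atomic<int>& lock, int& count, int GZ,
                     float* out, const float* partial, int n) {
  int expected = 0;
  while(!lock.compare_exchange_strong(expected, 1))  // spin: acquire
    expected = 0;
  if(count == 0)
    for(int i = 0; i < n; i++) out[i] = partial[i];  // first writer stores
  else
    for(int i = 0; i < n; i++) out[i] += partial[i]; // others accumulate
  count = (count == GZ - 1) ? 0 : count + 1;         // last writer resets
  lock.store(0);                                     // release
}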
"1" : "ldc_b"; - - - auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){ - std::string B = std::to_string(B_); - std::string CW = std::to_string(ICW_); - std::string CH = std::to_string(ICH_); - - if(is_chwn) { - return R"( - int )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(; - int )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(; - int )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w; - int )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)"; - } - else { - return R"( - int )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(; - int )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w; - int )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h; - int )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";"; - } - }; - - std::string result = -R"( -const tunable int TM = {16, 32, 64, 128}; -const tunable int TN = {16, 32, 64, 128}; -const tunable int TK = {)" + std::to_string(TK_) + "};"; -if(op_ == WGRAD) - result += "const tunable int GZ = {1};"; -else - result += "const tunable int GZ = {1};"; - -result += R"( -__constant__ int* delta_a = alloc_const int[)" + std::to_string(MAX_C_) + R"(]; - -void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, - restrict read_only align(16) )" + b_ty_ + R"( *B, - )" + c_ty_ + R"( *C, - int M, int N, int K, - int stride_h, int stride_w, - multiple_of(8) int lda_b, multiple_of(8) int lda_w, multiple_of(8) int lda_h, multiple_of(8) int lda_c, - multiple_of(8) int ldb_b, multiple_of(8) int ldb_w, multiple_of(8) int ldb_h, multiple_of(8) int ldb_c, - multiple_of(8) int ldc_b, multiple_of(8) int ldc_w, multiple_of(8) int ldc_h, multiple_of(8) int ldc_c, - int NB, - int AH, int AW, - int BH, int BW, - int CH, int CW, - int* locks, int grid0, int grid1, int grid2) { - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rz = get_program_id(2); - int rxa[TM] = ridx*TM + (0 ... TM); - int ryb[TN] = ridy*TN + (0 ... TN); - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... 
TK; - float acc[TM, TN] = 0; - int pad_h = BH / 2; - int pad_w = BW / 2;)"; - -/* A offsets */ -if(op_ == FPROP){ - result += - compute_bhw("ra", "TM", "rxa") + R"( - raw = raw * )" + stride_w + R"(; - rah = rah * )" + stride_h + R"(; - int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int offa0[TM, TK] = offxa[:, newaxis]; - __constant__ int* pd[TK] = delta_a + rka; - multiple_of(8) int d[TK] = *pd; - int offa1[TM, TK] = d[newaxis, :];)"; -} -if(op_ == BPROP){ - result += - compute_bhw("ra", "TM", "rxa") + R"( - int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int offa0[TM, TK] = offxa[:, newaxis]; - int offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; -} -if(op_ == WGRAD){ - result += - compute_bhw("ra", "TK", "rka") + R"( - int offa0[TK, TM] = rxa[newaxis, :] * lda_c; - int offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int offa1[TK, TM] = offxa[:, newaxis];)"; -} - -/* B offsets */ -if(op_ == FPROP){ - result += R"( - int offb0[TN, TK] = ryb[:, newaxis]; - int offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; -} -if(op_ == BPROP){ - result += R"( - int offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int offb1[TK, TN] = rkb[:, newaxis];)"; -} -if(op_ == WGRAD){ - result += - compute_bhw("rb", "TK", "rkb") + R"( - __constant__ int* pd[TN] = delta_a + ryb; - multiple_of(8) int d[TN] = *pd; - multiple_of(8) int shift[TK, TN] = d[newaxis, :]; - rbw = rbw * )" + stride_w + R"(; - rbh = rbh * )" + stride_h + R"(; - int offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; - int offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int offb1[TK, TN] = offkb[:, newaxis]; - )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0; - )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift; - )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1; - )" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)"; -} -else{ - result += R"( - )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; - )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)"; -} - -/* Main loop */ -/* Increment A pointers */ - result += R"( - bool checka[)" + AS + "] = (rka < K)" + bca0 + R"(; - bool checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; - )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; - )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; - for(int k = K; k > 0; k = k - TK){ - acc = dot()" + usea + "," + useb + R"(, acc); - bool checka[)" + AS + R"(] = k > TK; - bool checkb[)" + BS + R"(] = k > TK;)"; - -/* Increment A pointers */ -if(op_ == FPROP){ - result += R"( - pd = pd + TK; - d = *pd; - pa = pa + d[newaxis, :];)"; -} -if(op_ == BPROP){ - result += R"( - pa = pa + TK * lda_c;)"; -} -if(op_ == WGRAD){ - result += R"( - rka = rka + TK;)" - + compute_bhw("ra", "TK", "rka") + R"( - offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - pa = pa_base + offxa[:, newaxis];)"; -} - result += R"( - a = checka ? *pa : 0;)"; - -/* Increment B pointers */ -if(op_ == WGRAD){ - result += R"( - rkb = rkb + TK;)" - + compute_bhw("rb", "TK", "rkb") + R"( - rbw = rbw * )" + stride_w + R"(; - rbh = rbh * )" + stride_h + R"(; - offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; - pb = pb_base + offkb[:, newaxis];)"; -} -if(op_ == FPROP){ - result += R"( - pb = pb + TK * ldb_c;)"; -} -if(op_ == BPROP){ - result += R"( - pb = pb + TK;)"; -} - result += R"( - b = checkb ? *pb : 0; - } - int rxc[TM] = ridx*TM + (0 ... TM); - int ryc[TN] = ridy*TN + (0 ... 
TN);)"; - -/* C offsets */ -if(op_ == BPROP){ - result += - compute_bhw("rc", "TM", "rxc") + R"( - rcw = rcw * )" + stride_w + R"(; - rch = rch * )" + stride_h + R"(; - int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; - } -if(op_ == FPROP){ - result += - compute_bhw("rc", "TM", "rxc") + R"( - int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; -} -if(op_ == WGRAD){ - result += R"( - int offxc[TM] = rxc;)"; -} - result += R"(" - )" + c_ty_ + R"( c[TM, TN] = acc; - )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c; - bool checkc0[TM] = rxc < M; - bool checkc1[TN] = ryc < N; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; -if(op_ == BPROP){ - result += R"( - __constant__ int* pd[TN] = delta_a + ryc; - )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; - @checkc *shift_pc = c; - )"; -} -else{ - result += R"( - @checkc *pc = c;)"; -} - result += R"( -})"; - - os << result; -} - - -// small search space for partial auto-tuning -std::vector shift::search_space() const { - return dot_search_space(AT_, BT_); -} - -// simple parameter heuristics -params_t shift::heuristics() const { - return dot_heuristics(AT_, BT_, M_, N_, K_); -} - - -} -} diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp deleted file mode 100644 index ae9e1c783..000000000 --- a/lib/runtime/jit.cpp +++ /dev/null @@ -1,284 +0,0 @@ -#include -#include "triton/lang/lang.h" -#include "triton/codegen/selection/target.h" -#include "triton/ir/context.h" -#include "triton/ir/context_impl.h" -#include "triton/driver/device.h" -#include "triton/driver/error.h" -#include "triton/runtime/jit.h" -#include "llvm/IR/IRPrintingPasses.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Support/TargetRegistry.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/CodeGen/TargetPassConfig.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/Transforms/Scalar/EarlyCSE.h" -#include "llvm/Analysis/LoopPass.h" -#include "triton/tools/thread_pool.h" -#include - -typedef struct yy_buffer_state * YY_BUFFER_STATE; -extern int yyparse(); -extern YY_BUFFER_STATE yy_scan_string(const char * str); -extern void yy_delete_buffer(YY_BUFFER_STATE buffer); -extern triton::lang::translation_unit *ast_root; - -namespace triton { -namespace runtime{ - -void parallel_loop_nest(std::vector const & ranges, - std::function const &)> const & f, - size_t nthreads){ - size_t D = ranges.size(); - std::vector values(D, 0); - // thread pools -// ThreadPool pool(nthreads); - // Start with innermost loop - size_t i = D - 1; - while(true){ - // Execute function -// pool.enqueue(f,values); - f(values); - while(values[i]++ == ranges[i] - 1){ - if(i == 0) - return; - values[i--] = 0; - } - i = D - 1; - // Short sleep so that the thread pool doesn't grow too big - std::this_thread::sleep_for(std::chrono::microseconds(1)); - } -} - -template -void parallel_loop_nest(std::vector> const & iterates, std::function)> const & f, size_t nthreads){ - //Ranges to iterate over - std::vector ranges; - for(auto const & x: iterates) - ranges.push_back(x.size()); - //Proxy function - auto proxy = [&](std::vector const & idx){ - std::vector x(iterates.size()); - for(size_t i = 0; i < x.size(); ++i) - x[i] = iterates[i][idx[i]]; - f(x); - }; - //Iterate - parallel_loop_nest(ranges, proxy, nthreads); -} - -void 
void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std::function<void(std::vector<unsigned>)> const & f, size_t nthreads) { - ThreadPool pool(nthreads); - for(const std::vector<unsigned>& values: iterates) - pool.enqueue(f, values); -} - - -std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) { - llvm::Module* result = new llvm::Module(module.get_name(), llvm_context); - passes.selection.run(module, *result); - // add globals - for(auto x: module.globals()) - info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value(); - // number of threads - info.num_threads = passes.tune.get_num_threads(); - return std::unique_ptr<llvm::Module>(result); -} - -triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) { - // create AST from Triton-C source - YY_BUFFER_STATE buffer = yy_scan_string(src); - yyparse(); - yy_delete_buffer(buffer); - triton::lang::translation_unit *program = ast_root; - return program; -} - -std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) { - // create Triton-IR from AST - ir::module* module = new ir::module(name, context); - program->codegen(module); - return std::unique_ptr<ir::module>(module); -} - - -jit::jit(driver::context *context, unsigned nthreads): driver_context_(context), - target_(context->device()->make_target()), - nthreads_(nthreads) { } - -jit::~jit(){ } - -std::vector<unsigned> jit::get_valid(const char *name, const char *src) { - // find metaparameters - triton::lang::translation_unit* program = parse_program(name, src); - auto ptt_module = make_triton_module(name, triton_context_, program); - ir::module &tt_module = *ptt_module; - // set parameters - passes_wrapper passes(target_.get()); - passes.target_independent(tt_module); - passes.tune.run(tt_module); - auto mps = passes.tune.get_params(tt_module); - // create parameter ranges - std::vector<std::vector<unsigned>> ranges; - for(ir::metaparameter *mp: mps) - ranges.push_back(mp->get_space()); - // iterate over parameters - std::vector<unsigned> result; - parallel_loop_nest(ranges, [&](const std::vector<unsigned> params){ - if(!result.empty()) - return; - std::map<ir::value*, std::vector<std::string>> errors; - unsigned i = 0; - for(ir::metaparameter *mp: mps) - mp->set_value(params[i++]); - passes.tune.init(tt_module); - passes.tune.check_constraints(errors); - if(!errors.empty()) - return; - result = params; - }, 1); - if(result.empty()) - throw std::runtime_error("couldn't find valid parameters"); - return result; -} - - - -jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> & targets) { - // find metaparameters - triton::lang::translation_unit* program = parse_program(name, src); - auto ptt_module_0 = make_triton_module(name, triton_context_, program); - ir::module &tt_module_0 = *ptt_module_0; - // set parameters - passes_wrapper passes_0(target_.get()); - passes_0.target_independent(tt_module_0); - passes_0.tune.run(tt_module_0); - auto mps = passes_0.tune.get_params(tt_module_0); - // iterate over parameters - tune_res_t best; - // update_best - std::mutex mutex; - auto update_best = [&](const std::vector<unsigned> params){ - std::map<ir::value*, std::vector<std::string>> errors; - unsigned i = 0; - { - std::lock_guard<std::mutex> lock(mutex); - for(ir::metaparameter *mp: mps) - mp->set_value(params[i++]); -// for(size_t i = 0; i < params.size(); i++) -// std::cout << ((i==0)?"":", ") << params[i] << std::flush; -// std::cout << std::endl; - passes_0.tune.init(tt_module_0); - passes_0.tune.check_constraints(errors); -// for(auto x: errors) -//
for(auto e: x.second){ -// std::cout << x.first->get_name() << ": " << e << std::endl; -// } - } - if(!errors.empty()) - return; - // Deep copy of the module and tuner - triton::ir::context triton_context; - auto ptt_module_1 = make_triton_module(name, triton_context, program); - ir::module &tt_module_1 = *ptt_module_1; - // run passes - passes_wrapper passes_1(target_.get()); - passes_1.target_independent(tt_module_1); - passes_1.tune.run(tt_module_1); - i = 0; - for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){ - mp->set_value(params[i++]); - } - passes_1.tune.init(tt_module_1); - passes_1.target_dependent(tt_module_1); - driver::device* device = driver_context_->device(); - if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory()) - return; - if(passes_1.tune.get_num_threads() > device->max_threads_per_block()) - return; - // Compile - launch_information info; - llvm::LLVMContext llvm_context; - auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info); - std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module)); - double perf; - { - std::lock_guard<std::mutex> lock(mutex); - std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name)); - perf = benchmark(kernel.get(), info); - if(perf > best.perf){ - best.perf = perf; - best.params = params; - } - for(size_t i = 0; i < params.size(); i++) - std::cout << ((i==0)?"":", ") << params[i] << std::flush; - std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; - } - }; - - - if(targets.empty()) { - // create parameter ranges - std::vector<std::vector<unsigned>> ranges; - for(ir::metaparameter *mp: mps) - ranges.push_back(mp->get_space()); - parallel_loop_nest(ranges, update_best, nthreads_); - } - else { - parallel_for_each(targets, update_best, nthreads_); - } - - if(best.params.empty()) - throw std::runtime_error("auto-tuning didn't find valid parameters"); -// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl; - return best; -} - -void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) { - // set parameters - passes_wrapper passes(target_.get()); - passes.target_independent(tt_module); - passes.tune.run(tt_module); - unsigned i = 0; - for(ir::metaparameter* mp: passes.tune.get_params(tt_module)) - mp->set_value(params[i++]); - passes.tune.init(tt_module); - passes.target_dependent(tt_module); - // check constraints - std::map<ir::value*, std::vector<std::string>> errors; - passes.tune.check_constraints(errors); - for(auto x: errors){ - for(auto str: x.second) - std::cout << x.first->get_name() << ": " << str << std::endl; - } - if(errors.size()) - throw std::runtime_error("invalid parameters"); - // triton module -> llvm module - std::string name = tt_module.get_name(); - auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]); - // llvm module -> machine code - modules_.insert({name, driver::module::create(driver_context_, &*ll_module)}); -} - -void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) { - triton::lang::translation_unit* program = parse_program(name, src); - auto ptt_module = make_triton_module(name, triton_context_, program); - add_module(*ptt_module, params); -} - -driver::kernel *jit::get_function(const char *name) { - return driver::kernel::create(modules_.at(name), name); -} - -launch_information jit::get_launch_info(const char *name) { - return launch_info_map_.at(name); -} - - -} -}
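Putting the pieces together, a typical client of this JIT would autotune once and then compile the winning configuration. The following is a hypothetical sketch (tune_and_compile is not part of the library; it assumes a constructed jit and a Triton-C source string):

#include "triton/runtime/jit.h"

// autotune a kernel, compile the best parameters into the jit, and
// query the kernel plus its launch metadata by name
double tune_and_compile(triton::runtime::jit& engine, const char* src,
                        triton::runtime::jit::benchmark_t bench) {
  triton::runtime::jit::tune_res_t best = engine.autotune("matmul", src, bench, {});
  engine.add_module("matmul", src, best.params);   // compile winning config
  triton::driver::kernel* k = engine.get_function("matmul");
  triton::runtime::launch_information info = engine.get_launch_info("matmul");
  (void)k; (void)info;                             // ready to enqueue
  return best.perf;
}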