[general] removed dnn/ module and runtime/jit.cpp
This commit is contained in:
@@ -1,116 +0,0 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_BASE_H
|
||||
#define TDL_INCLUDE_DNN_BASE_H
|
||||
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace runtime{
|
||||
class jit;
|
||||
}
|
||||
|
||||
namespace dnn{
|
||||
|
||||
|
||||
enum autotuning_t{
|
||||
FULL_TUNING,
|
||||
PARTIAL_TUNING,
|
||||
NO_TUNING
|
||||
};
|
||||
|
||||
class base;
|
||||
struct launch_context_t{
|
||||
base *op;
|
||||
driver::kernel* kernel;
|
||||
triton::runtime::launch_information info;
|
||||
};
|
||||
|
||||
typedef std::vector<unsigned> params_t;
|
||||
|
||||
class base {
|
||||
friend class recompile_hash;
|
||||
friend class recompile_equal;
|
||||
|
||||
protected:
|
||||
// leading dimensions
|
||||
static void set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld);
|
||||
// list of retuning parameters
|
||||
virtual std::vector<int64_t> retune_params() const = 0;
|
||||
|
||||
private:
|
||||
// initialize
|
||||
virtual void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) = 0;
|
||||
// deinitialize
|
||||
virtual void deinit_impl() = 0;
|
||||
// enqueue
|
||||
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info) = 0;
|
||||
// number of flops
|
||||
virtual size_t num_flops() const = 0;
|
||||
// default parameters
|
||||
virtual std::vector<params_t> search_space() const;
|
||||
virtual params_t heuristics() const;
|
||||
// obtain execution jit
|
||||
std::pair<base*, triton::runtime::jit*> get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
base(const std::string& name);
|
||||
// triton-c source
|
||||
virtual void triton_c_src(std::ostream &os) const = 0;
|
||||
// clone
|
||||
virtual base* clone() const = 0;
|
||||
// enqueue
|
||||
base* enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
|
||||
// get profile
|
||||
launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune = PARTIAL_TUNING);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
|
||||
struct recompile_equal{
|
||||
bool operator()(base* x, base* y) const{
|
||||
return typeid(*x) == typeid(*y) &&
|
||||
x->retune_params() == y->retune_params();
|
||||
}
|
||||
};
|
||||
|
||||
struct recompile_hash{
|
||||
unsigned operator()(base* x) const{
|
||||
return x->retune_params()[0];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,110 +0,0 @@
|
||||
/* Copyright 2015-2019 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_BATCHNORM_H
|
||||
#define TDL_INCLUDE_DNN_BATCHNORM_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include "triton/dnn/base.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class batchnorm_forward: public base {
|
||||
private:
|
||||
// init
|
||||
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// clone
|
||||
base* clone() const;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
batchnorm_forward(int C, int D, int H, int W, int B,
|
||||
std::string ty = "float", float eps = 1e-5);
|
||||
// triton-c source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
|
||||
private:
|
||||
int32_t C_;
|
||||
int32_t D_;
|
||||
int32_t H_;
|
||||
int32_t W_;
|
||||
int32_t B_;
|
||||
std::string ty_;
|
||||
float eps_;
|
||||
int32_t DHWB_;
|
||||
float rcpDHWB_;
|
||||
};
|
||||
|
||||
class batchnorm_backward: public base{
|
||||
private:
|
||||
// init
|
||||
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
|
||||
void deinit_impl() { }
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// clone
|
||||
base* clone() const;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
batchnorm_backward(int C, int D, int H, int W, int B,
|
||||
std::string ty = "float", float eps = 1e-5);
|
||||
// triton-c source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
|
||||
private:
|
||||
int32_t C_;
|
||||
int32_t D_;
|
||||
int32_t H_;
|
||||
int32_t W_;
|
||||
int32_t B_;
|
||||
std::string ty_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,61 +0,0 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/dnn/base.h"
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
namespace blocksparse{
|
||||
|
||||
enum op_t{
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
class dot: public base {
|
||||
private:
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// default parameters
|
||||
std::vector<params_t> search_space() const;
|
||||
params_t heuristics() const;
|
||||
// init
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
|
||||
// deinit
|
||||
void deinit_impl();
|
||||
// source
|
||||
std::string triton_c_src_ydx() const;
|
||||
std::string triton_c_src_dw() const;
|
||||
public:
|
||||
// constructor
|
||||
dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
|
||||
// triton-c source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
// locks
|
||||
driver::buffer* get_locks() const;
|
||||
// clone
|
||||
base* clone() const;
|
||||
|
||||
private:
|
||||
std::string ab_ty_;
|
||||
std::string c_ty_;
|
||||
int32_t N_;
|
||||
int32_t S_;
|
||||
int32_t C_;
|
||||
int32_t K_;
|
||||
int32_t BS_;
|
||||
int32_t nlocks_;
|
||||
int32_t nblocks_;
|
||||
std::shared_ptr<driver::buffer> locks_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,155 +0,0 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/dnn/base.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class conv: public base{
|
||||
public:
|
||||
enum type {
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
private:
|
||||
// initialize
|
||||
std::tuple<int32_t, int32_t, int32_t, int32_t>
|
||||
unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW);
|
||||
void build_b_deltas();
|
||||
void build_a_deltas();
|
||||
void build_masks();
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
|
||||
void set_arg(driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
driver::buffer *bias);
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// clone
|
||||
base* clone() const;
|
||||
|
||||
public:
|
||||
|
||||
conv(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
int stride_d, int stride_h, int stride_w,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int upsample_d, int upsample_h, int upsample_w,
|
||||
std::string a_ty = "float", std::string b_ty = "float",
|
||||
type ty = FPROP, bool bias = false);
|
||||
|
||||
// accessors
|
||||
size_t a_size();
|
||||
size_t b_size();
|
||||
size_t c_size();
|
||||
std::vector<int32_t> c_shapes();
|
||||
// default params
|
||||
std::vector<unsigned> default_params();
|
||||
|
||||
// triton-c source code
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
|
||||
// cpu reference implementations
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
|
||||
|
||||
private:
|
||||
// image size
|
||||
int32_t NB_;
|
||||
int32_t NC_;
|
||||
int32_t AD_;
|
||||
int32_t AH_;
|
||||
int32_t AW_;
|
||||
// filter size
|
||||
int32_t BD_;
|
||||
int32_t BH_;
|
||||
int32_t BW_;
|
||||
int32_t NF_;
|
||||
// activation size
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
int32_t CW_;
|
||||
// striding
|
||||
int32_t stride_d_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
// padding
|
||||
int32_t pad_d_;
|
||||
int32_t pad_h_;
|
||||
int32_t pad_w_;
|
||||
// upsampling
|
||||
int32_t upsample_d_;
|
||||
int32_t upsample_h_;
|
||||
int32_t upsample_w_;
|
||||
// equivalent matmul
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
int32_t K_;
|
||||
// helpers
|
||||
int32_t Fs_;
|
||||
int32_t TK_;
|
||||
int32_t Luts_;
|
||||
// memory strides for A
|
||||
std::vector<int32_t> shapes_a_;
|
||||
std::vector<int32_t> ld_a_;
|
||||
// memory strides for B
|
||||
std::vector<int32_t> shapes_b_;
|
||||
std::vector<int32_t> ld_b_;
|
||||
// memory stride for C
|
||||
std::vector<int32_t> shapes_c_;
|
||||
std::vector<int32_t> ld_c_;
|
||||
// constant memory
|
||||
std::vector<int32_t> h_a_deltas_;
|
||||
std::vector<int32_t> h_b_deltas_;
|
||||
std::vector<int32_t> h_masks_;
|
||||
driver::buffer* d_a_deltas_;
|
||||
driver::buffer* d_b_deltas_;
|
||||
driver::buffer* d_masks_;
|
||||
driver::buffer* d_locks_;
|
||||
bool is_a_deltas_cst;
|
||||
bool is_b_deltas_cst_;
|
||||
bool is_mask_cst_;
|
||||
// data type
|
||||
std::string a_ty_;
|
||||
std::string b_ty_;
|
||||
// conv type
|
||||
type ty_;
|
||||
bool bias_;
|
||||
bool b_trans_;
|
||||
bool b_lut_;
|
||||
// axis index
|
||||
int32_t a_inner_idx_;
|
||||
int32_t a_outer_idx_;
|
||||
int32_t a_pix_idx_;
|
||||
int32_t b_inner_idx_;
|
||||
int32_t b_outer_idx_;
|
||||
int32_t b_pix_idx_;
|
||||
int32_t c_outer_0_idx_;
|
||||
int32_t c_outer_1_idx_;
|
||||
int32_t c_pix_idx;
|
||||
// maximum grid size for loc
|
||||
int32_t max_grid_0_;
|
||||
int32_t max_grid_1_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
@@ -1,79 +0,0 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/dnn/base.h"
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class dot: public base {
|
||||
private:
|
||||
// initialize
|
||||
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information);
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// default parameters
|
||||
virtual std::vector<params_t> search_space() const;
|
||||
virtual params_t heuristics() const;
|
||||
|
||||
public:
|
||||
dot(int M, int N, int K, bool AT, bool BT,
|
||||
std::string a_ty, std::string b_ty, std::string c_ty,
|
||||
unsigned align_lda, unsigned align_ldb, unsigned align_ldc);
|
||||
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
|
||||
// triton-c source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
|
||||
// clone
|
||||
base* clone() const;
|
||||
|
||||
// CPU reference implementation
|
||||
template<class T, bool AT, bool BT>
|
||||
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
||||
size_t M, size_t N, size_t K){
|
||||
for(size_t m = 0; m < M; m++)
|
||||
for(size_t n = 0; n < N; n++){
|
||||
float acc = 0;
|
||||
for(size_t k = 0; k < K; k++)
|
||||
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
||||
c[m + n*M] = static_cast<T>(acc);
|
||||
}
|
||||
}
|
||||
template<class T>
|
||||
void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
|
||||
if(AT_ && BT_)
|
||||
dot::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
|
||||
else if(AT_ && !BT_)
|
||||
dot::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
|
||||
else if(!AT_ && BT_)
|
||||
dot::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
|
||||
else
|
||||
dot::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
int32_t K_;
|
||||
bool AT_;
|
||||
bool BT_;
|
||||
std::string a_ty_;
|
||||
std::string b_ty_;
|
||||
std::string c_ty_;
|
||||
unsigned align_lda_;
|
||||
unsigned align_ldb_;
|
||||
unsigned align_ldc_;
|
||||
driver::buffer *locks_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
@@ -1,186 +0,0 @@
|
||||
#ifndef TRITON_DNN_HEURISTICS_H
|
||||
#define TRITON_DNN_HEURISTICS_H
|
||||
|
||||
#include <vector>
|
||||
#include "triton/dnn/base.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
/* Dense matrix multiplication */
|
||||
|
||||
typedef std::vector<unsigned> params_t;
|
||||
typedef std::tuple<bool, bool> trans_key_t;
|
||||
typedef std::tuple<size_t, size_t> size_key_t;
|
||||
static const std::map<trans_key_t, std::map<size_key_t, params_t>> dot_params = {
|
||||
/* NN */
|
||||
{trans_key_t(false, false), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}},
|
||||
{size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
|
||||
{size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}}
|
||||
}},
|
||||
/* NT */
|
||||
{trans_key_t(false, true), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}},
|
||||
{size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
|
||||
{size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}},
|
||||
{size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}},
|
||||
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
|
||||
{size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}},
|
||||
{size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}},
|
||||
{size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}},
|
||||
{size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}},
|
||||
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
|
||||
{size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}},
|
||||
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}},
|
||||
{size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}},
|
||||
{size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}},
|
||||
{size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}},
|
||||
{size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}},
|
||||
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}},
|
||||
{size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}},
|
||||
{size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}},
|
||||
{size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}},
|
||||
{size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}
|
||||
}},
|
||||
/* TN */
|
||||
{trans_key_t(true, false), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}},
|
||||
{size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}},
|
||||
}},
|
||||
/* TT */
|
||||
{trans_key_t(true, true), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
|
||||
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
|
||||
{size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}},
|
||||
{size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}},
|
||||
{size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}},
|
||||
{size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}},
|
||||
{size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}},
|
||||
{size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}},
|
||||
{size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
|
||||
{size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}},
|
||||
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}},
|
||||
{size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}},
|
||||
{size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}},
|
||||
{size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}},
|
||||
{size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
|
||||
}}
|
||||
};
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
std::vector<params_t> result;
|
||||
for(auto x: dot_params.at(trans_key_t{AT, BT}))
|
||||
result.push_back(x.second);
|
||||
return result;
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
// return {4, 4, 128, 8, 4, 128, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
return dot_params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
|
||||
/* Block-sparse matrix multiplication */
|
||||
|
||||
static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
|
||||
/* FPROP */
|
||||
{{true, 32}, std::map<size_t, params_t>{
|
||||
{32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
|
||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
|
||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
|
||||
}},
|
||||
|
||||
{{true, 16}, std::map<size_t, params_t>{
|
||||
{32, {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}},
|
||||
{64, {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}},
|
||||
{128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}}
|
||||
}},
|
||||
|
||||
{{true, 8}, std::map<size_t, params_t>{
|
||||
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}},
|
||||
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}},
|
||||
{128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}
|
||||
}},
|
||||
|
||||
/* BPROP */
|
||||
{{false, 32}, std::map<size_t, params_t>{
|
||||
{32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
|
||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
|
||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
|
||||
}},
|
||||
|
||||
{{false, 16}, std::map<size_t, params_t>{
|
||||
{32, {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}},
|
||||
{64, {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}},
|
||||
{128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}}
|
||||
}},
|
||||
|
||||
{{false, 8}, std::map<size_t, params_t>{
|
||||
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}},
|
||||
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}},
|
||||
{128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}
|
||||
}}
|
||||
};
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
inline std::vector<params_t> bsdot_search_space(bool is_fprop, size_t block_size) {
|
||||
std::vector<params_t> result;
|
||||
for(auto x: bsdot_params.at({is_fprop, block_size}))
|
||||
result.push_back(x.second);
|
||||
return result;
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t bsdot_heuristics(bool is_fprop, size_t block_size, size_t N, size_t S) {
|
||||
return bsdot_params.at({is_fprop,block_size}).at(128);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,192 +0,0 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_SHIFT_H
|
||||
#define TDL_INCLUDE_DNN_SHIFT_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include "triton/dnn/base.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
enum op_t {
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
enum layout_t {
|
||||
NCHW,
|
||||
CHWN
|
||||
};
|
||||
|
||||
class shift: public base {
|
||||
private:
|
||||
// initialize and enqueue
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
|
||||
void deinit_impl();
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
std::vector<params_t> search_space() const;
|
||||
params_t heuristics() const;
|
||||
|
||||
public:
|
||||
|
||||
shift(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
int stride_h, int stride_w,
|
||||
const int32_t* shift_h, const int32_t* shift_w,
|
||||
std::string a_ty = "float", std::string b_ty = "float",
|
||||
op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
|
||||
|
||||
// look-up table
|
||||
void build_delta_a();
|
||||
void build_masks();
|
||||
// accessors
|
||||
size_t c_size();
|
||||
std::vector<int32_t> c_shapes();
|
||||
// equivalent GEMM
|
||||
bool AT() const;
|
||||
bool BT() const;
|
||||
size_t M() const;
|
||||
size_t N() const;
|
||||
size_t K() const;
|
||||
size_t lda() const;
|
||||
size_t ldb() const;
|
||||
size_t ldc() const;
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
// retuning parameters
|
||||
std::vector<int64_t> retune_params() const;
|
||||
// clone
|
||||
base* clone() const;
|
||||
// cpu reference
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_ref(OUT_DTYPE* O,
|
||||
const IN_DTYPE* I,
|
||||
const IN_DTYPE* F)
|
||||
{
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < AH_; ++p)
|
||||
for(int32_t q = 0; q < AW_; ++q)
|
||||
for(int32_t bs = 0; bs < B_; ++bs)
|
||||
for(int32_t k = 0; k < F_; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < C_; ++c){
|
||||
int32_t h = p;
|
||||
int32_t w = q;
|
||||
if(h >= BH_/2 && h < AH_ - BH_/2
|
||||
&& w >= BW_/2 && w < AW_ - BW_/2){
|
||||
h += shift_h_[c];
|
||||
w += shift_w_[c];
|
||||
}
|
||||
IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_];
|
||||
IN_DTYPE b = F[k + c*F_];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t MAX_C_;
|
||||
int32_t TK_;
|
||||
// image size
|
||||
int32_t B_;
|
||||
int32_t C_;
|
||||
int32_t AD_;
|
||||
int32_t AH_;
|
||||
int32_t AW_;
|
||||
// filter size
|
||||
int32_t BD_;
|
||||
int32_t BH_;
|
||||
int32_t BW_;
|
||||
int32_t F_;
|
||||
// activation size
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
int32_t CW_;
|
||||
// interior image size
|
||||
int32_t IAD_;
|
||||
int32_t IAH_;
|
||||
int32_t IAW_;
|
||||
// interior activation size
|
||||
int32_t ICD_;
|
||||
int32_t ICH_;
|
||||
int32_t ICW_;
|
||||
// equivalent matmul
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
int32_t K_;
|
||||
// shapes
|
||||
std::vector<int32_t> shapes_c_;
|
||||
// strides
|
||||
int32_t stride_d_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
// memory strides
|
||||
int32_t lda_n_, lda_c_, lda_h_, lda_w_;
|
||||
int32_t ldb_n_, ldb_c_, ldb_h_, ldb_w_;
|
||||
int32_t ldc_n_, ldc_f_, ldc_h_, ldc_w_;
|
||||
// shift values
|
||||
const int32_t* shift_h_;
|
||||
const int32_t* shift_w_;
|
||||
bool shift_edge_h_;
|
||||
bool shift_edge_w_;
|
||||
// look-up tables
|
||||
std::vector<int32_t> h_delta_a;
|
||||
std::vector<int32_t> h_delta_b;
|
||||
driver::buffer* d_delta_a;
|
||||
driver::buffer* d_delta_b;
|
||||
// data types
|
||||
std::string a_ty_;
|
||||
std::string b_ty_;
|
||||
std::string c_ty_;
|
||||
// convolution type
|
||||
op_t op_;
|
||||
bool bias_;
|
||||
// transpose
|
||||
bool AT_;
|
||||
bool BT_;
|
||||
// layout
|
||||
layout_t layout_;
|
||||
// locks
|
||||
size_t max_locks_;
|
||||
driver::buffer *locks_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,136 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_JIT_H
|
||||
#define TDL_INCLUDE_JIT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
#include <functional>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace lang{
|
||||
class translation_unit;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class tune;
|
||||
}
|
||||
}
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class context;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
namespace runtime{
|
||||
|
||||
class jit {
|
||||
public:
|
||||
typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
|
||||
|
||||
struct tune_res_t{
|
||||
double perf;
|
||||
std::vector<unsigned> params;
|
||||
};
|
||||
|
||||
struct passes_wrapper {
|
||||
passes_wrapper(codegen::target* target)
|
||||
: tune(0),
|
||||
shmem_liveness(&shmem_info),
|
||||
shmem_allocation(&shmem_liveness, &shmem_info, &tune),
|
||||
shmem_barriers(&shmem_allocation, &shmem_info),
|
||||
vectorize(&tune),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
||||
dce(),
|
||||
peephole(),
|
||||
alignment_info(),
|
||||
reassociate(&tune),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
alignment_info.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
}
|
||||
|
||||
codegen::selection selection;
|
||||
codegen::analysis::tune tune;
|
||||
codegen::analysis::shmem::info shmem_info;
|
||||
codegen::analysis::shmem::liveness shmem_liveness;
|
||||
codegen::analysis::shmem::allocation shmem_allocation;
|
||||
codegen::analysis::alignment_info alignment_info;
|
||||
codegen::transform::shmem_barriers shmem_barriers;
|
||||
codegen::transform::vectorize vectorize;
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
private:
|
||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
|
||||
std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
|
||||
triton::lang::translation_unit *parse_program(const char *name, const char *src);
|
||||
|
||||
public:
|
||||
jit(driver::context* context, unsigned nthreads = 4);
|
||||
~jit();
|
||||
std::vector<unsigned> get_valid(const char *name, const char *src);
|
||||
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned> > &targets = {});
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel* get_function(const char* name);
|
||||
launch_information get_launch_info(const char* name);
|
||||
|
||||
private:
|
||||
std::map<std::string, driver::module*> modules_;
|
||||
driver::context* driver_context_;
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
std::map<std::string, launch_information> launch_info_map_;
|
||||
std::shared_ptr<triton::codegen::target> target_;
|
||||
unsigned nthreads_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,94 +0,0 @@
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include "triton/dnn/base.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
|
||||
void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[size - 1] = 1;
|
||||
for(int i = size - 1; i >= 1; i--)
|
||||
ld[i - 1] = shapes[i] * ld[i];
|
||||
}
|
||||
|
||||
|
||||
base::base(const std::string& name)
|
||||
: name_(name) { }
|
||||
|
||||
std::vector<params_t> base::search_space() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
params_t base::heuristics() const {
|
||||
return *search_space().begin();
|
||||
}
|
||||
|
||||
std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
static std::unordered_map<base*, std::unique_ptr<rt::jit>, recompile_hash, recompile_equal> m_jit;
|
||||
driver::context* ctx = stream->context();
|
||||
rt::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
base* clone = this->clone();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
|
||||
std::ostringstream oss;
|
||||
clone->triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
rt::launch_information info) {
|
||||
// launch info
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
|
||||
clone->enqueue_impl(stream, kernel, args, info);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
|
||||
clone->deinit_impl();
|
||||
// std::cout << ts * 1e-6 << std::endl;
|
||||
return num_flops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune == FULL_TUNING || autotune == PARTIAL_TUNING) {
|
||||
std::vector<params_t> space = {};
|
||||
if(autotune == PARTIAL_TUNING)
|
||||
space = search_space();
|
||||
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else{
|
||||
params_t params = heuristics();
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
rt::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
|
||||
}
|
||||
/* retrieved compiled template */
|
||||
else {
|
||||
jit = m_jit.at(this).get();
|
||||
}
|
||||
auto it = m_jit.find(this);
|
||||
return {it->first, jit};
|
||||
}
|
||||
|
||||
base* base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
launch_context_t info = get_launch_context(stream, args, autotune);
|
||||
info.op->enqueue_impl(stream, info.kernel, args, info.info);
|
||||
return info.op;
|
||||
}
|
||||
|
||||
launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
std::pair<base*, rt::jit*> profile = get_profile_impl(stream, args, autotune);
|
||||
driver::kernel* kernel = profile.second->get_function(name_.c_str());
|
||||
rt::launch_information info = profile.second->get_launch_info(name_.c_str());
|
||||
return {profile.first, kernel, info};
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,227 +0,0 @@
|
||||
/* Copyright 2015-2019 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/dnn/batchnorm.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
/* ---------------
|
||||
* Forward
|
||||
* --------------- */
|
||||
|
||||
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
|
||||
: base("batchnorm_forward"),
|
||||
C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
|
||||
DHWB_ = D_*H_*W_*B_;
|
||||
rcpDHWB_ = (float)1 / DHWB_;
|
||||
}
|
||||
|
||||
size_t batchnorm_forward::num_flops() const {
|
||||
return C_*DHWB_;
|
||||
}
|
||||
|
||||
|
||||
std::vector<int64_t> batchnorm_forward::retune_params() const {
|
||||
return {C_, D_, H_, W_, B_};
|
||||
}
|
||||
|
||||
base* batchnorm_forward::clone() const {
|
||||
return new batchnorm_forward(*this);
|
||||
}
|
||||
|
||||
void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
runtime::launch_information info)
|
||||
{
|
||||
driver::buffer *y = args[0], *m = args[1], *v = args[2];
|
||||
driver::buffer *x = args[3], *g = args[4], *b = args[5];
|
||||
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
|
||||
kernel->setArg(0, y);
|
||||
kernel->setArg(1, m);
|
||||
kernel->setArg(2, v);
|
||||
kernel->setArg(3, x);
|
||||
kernel->setArg(4, g);
|
||||
kernel->setArg(5, b);
|
||||
kernel->setArg(6, DHWB_);
|
||||
kernel->setArg(7, rcpDHWB_);
|
||||
kernel->setArg(8, eps_);
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_forward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int TM = {128};
|
||||
|
||||
void batchnorm_forward(float *Y, float *M, float *V,
|
||||
restrict read_only float *X,
|
||||
restrict read_only float *G,
|
||||
restrict read_only float *B,
|
||||
int DHWN,
|
||||
float rcpDHWN, float eps) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
float *px[TM];
|
||||
float x[TM] = 0;
|
||||
int c = get_program_id(1);
|
||||
float g = *(G + c);
|
||||
float b = *(B + c);
|
||||
|
||||
float mean[TM] = 0;
|
||||
px = X + rx + c*DHWN;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
mean = mean + x;
|
||||
px = px + TM;
|
||||
}
|
||||
float *pm = M + c;
|
||||
float m = __sum(mean, 0) * rcpDHWN;
|
||||
*pm = m;
|
||||
|
||||
float var[TM] = 0;
|
||||
px = X + rx + c*DHWN;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
x = x - m;
|
||||
var = var + x*x;
|
||||
px = px + TM;
|
||||
}
|
||||
float v = __sum(var, 0) * rcpDHWN;
|
||||
float *pv = V + c;
|
||||
*pv = v;
|
||||
float rstdg = 1 / sqrt(v + eps) * g;
|
||||
|
||||
px = X + rx + c*DHWN;
|
||||
float* py[TM] = Y + rx + c*DHWN;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
float y[TM] = (x - m)*rstdg + b;
|
||||
*py = y;
|
||||
px = px + TM;
|
||||
py = py + TM;
|
||||
}
|
||||
})";
|
||||
}
|
||||
|
||||
/* ---------------
|
||||
* Backward
|
||||
* --------------- */
|
||||
|
||||
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
|
||||
: base("batchnorm_backward"),
|
||||
C_(C), D_(D), H_(H), W_(W), B_(B),
|
||||
ty_(ty), eps_(eps)
|
||||
{ }
|
||||
|
||||
size_t batchnorm_backward::num_flops() const {
|
||||
return C_*D_*H_*W_*B_;
|
||||
}
|
||||
|
||||
std::vector<int64_t> batchnorm_backward::retune_params() const {
|
||||
return {C_, D_, H_, W_, B_};
|
||||
}
|
||||
|
||||
base* batchnorm_backward::clone() const {
|
||||
return new batchnorm_backward(*this);
|
||||
}
|
||||
|
||||
void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
|
||||
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
|
||||
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
|
||||
kernel->setArg(0, dx);
|
||||
kernel->setArg(1, dg);
|
||||
kernel->setArg(2, db);
|
||||
kernel->setArg(3, dy);
|
||||
kernel->setArg(4, x);
|
||||
kernel->setArg(5, g);
|
||||
kernel->setArg(6, m);
|
||||
kernel->setArg(7, v);
|
||||
kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
|
||||
kernel->setArg(9, (float)1/(D_*H_*W_*B_));
|
||||
kernel->setArg(10, eps_);
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_backward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int TM = {128};
|
||||
|
||||
void batchnorm_backward(float *DX, float *DG, float *DB,
|
||||
restrict read_only float *DY,
|
||||
restrict read_only float *X,
|
||||
restrict read_only float *G,
|
||||
restrict read_only float *M,
|
||||
restrict read_only float *V,
|
||||
int DHWN, float rcpDHWN, float epsilon) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
int c = get_program_id(1);
|
||||
int offset = c*DHWN;
|
||||
float g = *(G + c);
|
||||
float mean = *(M + c);
|
||||
float var = *(V + c);
|
||||
float rstd = 1 / sqrt(var + epsilon);
|
||||
float* px[TM];
|
||||
float* pdx[TM];
|
||||
float* pdy[TM];
|
||||
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
float dg[TM] = 0;
|
||||
float db[TM] = 0;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
float x[TM] = *px;
|
||||
float dy[TM] = *pdy;
|
||||
dg = dg + dy*(x - mean)*rstd;
|
||||
db = db + dy;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
}
|
||||
float sdg = __sum(dg, 0);
|
||||
float sdb = __sum(db, 0);
|
||||
float *pdg = DG + c;
|
||||
float *pdb = DB + c;
|
||||
*pdg = sdg;
|
||||
*pdb = sdb;
|
||||
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
pdx = DX + rx + offset;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
float x[TM] = *px;
|
||||
float dy[TM] = *pdy;
|
||||
float xhat[TM] = (x - mean) * rstd;
|
||||
float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
|
||||
float dx[TM] = (dy - xtmp) * rstd * g;
|
||||
*pdx = dx;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
pdx = pdx + TM;
|
||||
}
|
||||
})";
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,238 +0,0 @@
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include "triton/dnn/blocksparse/dot.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
namespace blocksparse{
|
||||
|
||||
|
||||
size_t dot::num_flops() const {
|
||||
return 2.*nblocks_*BS_*BS_*N_;
|
||||
}
|
||||
|
||||
std::vector<int64_t> dot::retune_params() const{
|
||||
return {N_, S_, C_, BS_, nlocks_, op_};
|
||||
}
|
||||
|
||||
std::vector<params_t> dot::search_space() const {
|
||||
return bsdot_search_space(op_ == FPROP, BS_);
|
||||
}
|
||||
|
||||
params_t dot::heuristics() const {
|
||||
return bsdot_heuristics(op_ == FPROP, BS_, N_, S_);
|
||||
}
|
||||
|
||||
base * dot::clone() const {
|
||||
return new dot(*this);
|
||||
}
|
||||
|
||||
dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
|
||||
const std::string& ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op):
|
||||
base("bsdot"),
|
||||
N_(N), K_(K), S_(S), C_(C),
|
||||
ab_ty_(ty), c_ty_(ty),
|
||||
BS_(BS), nlocks_(nlocks), nblocks_(nblocks), op_(op){
|
||||
}
|
||||
|
||||
void dot::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
|
||||
int32_t TM = info.globals["TM"];
|
||||
size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
if(nlocks_ && !locks_){
|
||||
locks_.reset(triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4));
|
||||
((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
}
|
||||
}
|
||||
|
||||
void dot::deinit_impl() {
|
||||
}
|
||||
|
||||
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args, runtime::launch_information info) {
|
||||
driver::buffer *a = args[0];
|
||||
driver::buffer *b = args[1];
|
||||
driver::buffer *c = args[2];
|
||||
driver::buffer *lut = args[3];
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
kernel->setArg(3, N_);
|
||||
kernel->setArg(4, BS_);
|
||||
kernel->setArg(5, N_);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(3, N_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, BS_);
|
||||
}
|
||||
kernel->setArg(6, N_);
|
||||
kernel->setArg(7, lut);
|
||||
kernel->setArg(8, locks_.get());
|
||||
kernel->setArg(9, nlocks_);
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
int32_t TM = info.globals["TM"];
|
||||
size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
size_t grid_1 = S_;
|
||||
if(nlocks_)
|
||||
((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
|
||||
}
|
||||
else{
|
||||
size_t grid_0 = nblocks_;
|
||||
stream->enqueue(kernel, {grid_0, 1, 1}, {info.num_threads, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
driver::buffer* dot::get_locks() const {
|
||||
return locks_.get();
|
||||
}
|
||||
|
||||
std::string dot::triton_c_src_ydx() const {
|
||||
bool AT = (op_ == WGRAD);
|
||||
bool BT = (op_ == FPROP);
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string sizea = "TM, TK";
|
||||
std::string sizeb = BT ? "TN, TK" : "TK, TN";
|
||||
std::string bca0 = ":, newaxis";
|
||||
std::string bca1 = "newaxis, :";
|
||||
std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
|
||||
std::string lda0 = AT ? "*lda" : "";
|
||||
std::string lda1 = AT ? "" : "*lda";
|
||||
std::string ldb0 = BT ? "" : "*ldb";
|
||||
std::string ldb1 = BT ? "*ldb" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int TM = {16, 32, 64, 128};
|
||||
const tunable int TN = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int TK = {)" + std::to_string(BS_) + R"(};
|
||||
|
||||
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
)" + c_ty_ + R"(* C,
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_program_id(0);
|
||||
float acc[TM, TN] = 0;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int *header = lut + get_program_id(1) * 4;
|
||||
int offset = *(header + 0);
|
||||
int K = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
int lockid = *(header + 3);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = 0 ... TN;
|
||||
int *plut = lut + offset * 2;
|
||||
int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
|
||||
int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
bool checka[TM, TK] = (rxa < N)[:, newaxis];
|
||||
for(int k = K; k > 0; k = k - 1) {
|
||||
int ak = *(plut + 0);
|
||||
int bk = *(plut + 1);
|
||||
)" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
|
||||
)" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
|
||||
)" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
|
||||
)" + ab_ty_ + " b[" + sizeb + R"(] = *pb;
|
||||
acc = dot()" + usea + ", " + useb + R"(, acc);
|
||||
plut = plut + 2;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = column * TN + (0 ... TN);
|
||||
)" + c_ty_ + R"(" c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
|
||||
bool checkc[TM, TN] = (rxc < N)[:, newaxis];
|
||||
if(lockid == 0) {
|
||||
@checkc *pc = c;
|
||||
}
|
||||
else {
|
||||
int *plock = locks + ridx*nlocks + lockid - 1;
|
||||
int *pcount = plock + get_num_program(0)*nlocks;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
@checkc *pc = c;
|
||||
else
|
||||
@checkc *pc = c + *pc;
|
||||
__atomic_exch(pcount, 1);
|
||||
__atomic_exch(plock, 0);
|
||||
}
|
||||
})";
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string dot::triton_c_src_dw() const {
|
||||
bool AT = (op_ == WGRAD);
|
||||
bool BT = (op_ == FPROP);
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string sizea = AT ? "TK, TM" : "TM, TK";
|
||||
std::string sizeb = BT ? "TN, TK" : "TK, TN";
|
||||
std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
|
||||
std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
|
||||
std::string lda0 = AT ? "*lda" : "";
|
||||
std::string lda1 = AT ? "" : "*lda";
|
||||
std::string ldb0 = BT ? "" : "*ldb";
|
||||
std::string ldb1 = BT ? "*ldb" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int TM = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int TN = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int TK = {32};
|
||||
|
||||
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
)" + c_ty_ + R"(* C,
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_program_id(0);
|
||||
float acc[TM, TN] = 0;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int *header = lut + ridx * 2;
|
||||
int offx = *(header + 0);
|
||||
int offy = *(header + 1);
|
||||
int rxa[TM] = offx*TM + (0 ... TM);
|
||||
int ryb[TN] = offy*TN + (0 ... TN);
|
||||
bool checka[TK, TM] = (rka < N)[:, newaxis];
|
||||
bool checkb[TK, TN] = (rkb < N)[:, newaxis];
|
||||
int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
|
||||
int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
)" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
|
||||
)" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
|
||||
)" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
|
||||
)" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
|
||||
for(int k = N; k > 0; k = k - TK) {
|
||||
acc = dot()" + usea + ", " + useb + R"(, acc);
|
||||
pa = pa + TK)" + lda1 + R"(;
|
||||
pb = pb + TK)" + ldb1 + R"(;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int rxc[TM] = (0 ... TM);
|
||||
int ryc[TN] = (0 ... TN);
|
||||
)" + c_ty_ + R"( c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
|
||||
*pc = c;
|
||||
})";
|
||||
|
||||
return result;
|
||||
}
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
if(op_ == FPROP || op_ == BPROP)
|
||||
os << triton_c_src_ydx();
|
||||
else
|
||||
os << triton_c_src_dw();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
720
lib/dnn/conv.cpp
720
lib/dnn/conv.cpp
@@ -1,720 +0,0 @@
|
||||
#include <cmath>
|
||||
#include "triton/dnn/conv.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
conv::conv(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
int stride_d, int stride_h, int stride_w,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int upsample_d, int upsample_h, int upsample_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
: base("conv"),
|
||||
NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
|
||||
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
|
||||
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
|
||||
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias)
|
||||
{
|
||||
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||
// shapes
|
||||
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
|
||||
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
|
||||
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
|
||||
// a layout - NCHW
|
||||
a_outer_idx_ = 0;
|
||||
a_inner_idx_ = 1;
|
||||
a_pix_idx_ = 2;
|
||||
// b layout - CRSK
|
||||
b_inner_idx_ = 0;
|
||||
b_pix_idx_ = 1;
|
||||
b_outer_idx_ = 4;
|
||||
// c layout - NKPQ
|
||||
c_outer_0_idx_ = 0;
|
||||
c_outer_1_idx_ = 1;
|
||||
c_pix_idx = 2;
|
||||
// swap a and c for bprop
|
||||
if(ty_ == BPROP){
|
||||
std::swap(AD_, CD_);
|
||||
std::swap(AH_, CH_);
|
||||
std::swap(AW_, CW_);
|
||||
shapes_a_.swap(shapes_c_);
|
||||
std::swap(stride_d_, upsample_d_);
|
||||
std::swap(stride_h_, upsample_h_);
|
||||
std::swap(stride_w_, upsample_w_);
|
||||
pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2;
|
||||
pad_h_ = (CH_*stride_h_ - AH_*upsample_h_ + BH_ - 1 - stride_h_ + 1)/2;
|
||||
pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2;
|
||||
std::swap(b_inner_idx_, b_outer_idx_);
|
||||
std::swap(NC_, NF_);
|
||||
}
|
||||
// swap b and c for wgrad
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
std::swap(BD_, CD_);
|
||||
std::swap(BH_, CH_);
|
||||
std::swap(BW_, CW_);
|
||||
std::swap(a_outer_idx_, a_inner_idx_);
|
||||
std::swap(b_inner_idx_, c_outer_0_idx_);
|
||||
std::swap(b_outer_idx_, c_outer_1_idx_);
|
||||
std::swap(b_pix_idx_, c_pix_idx);
|
||||
}
|
||||
// leading dimensions
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
set_ld(shapes_b_, ld_b_);
|
||||
set_ld(shapes_c_, ld_c_);
|
||||
// equivalent matmul
|
||||
bool upsampled_b = (ty_ == BPROP) && (upsample_d_ > 1 || upsample_h_ > 1 || upsample_w_ > 1);
|
||||
b_trans_ = ty_ != BPROP;
|
||||
b_lut_ = ty_ == WGRAD || upsampled_b;
|
||||
M_ = shapes_c_[c_outer_0_idx_]*shapes_c_[c_pix_idx]*shapes_c_[c_pix_idx+1]*shapes_c_[c_pix_idx+2];
|
||||
N_ = shapes_c_[c_outer_1_idx_];
|
||||
K_ = shapes_b_[b_inner_idx_]*BD_*BH_*BW_;
|
||||
// look-up table info
|
||||
if(ty_ == FPROP)
|
||||
Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
else
|
||||
Fs_ = K_;
|
||||
TK_ = 8;
|
||||
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
|
||||
build_a_deltas();
|
||||
if(b_lut_)
|
||||
build_b_deltas();
|
||||
build_masks();
|
||||
size_t cst_size = h_b_deltas_.size()*4;
|
||||
is_b_deltas_cst_ = cst_size < 65536;
|
||||
cst_size += h_a_deltas_.size()*4;
|
||||
is_a_deltas_cst = cst_size < 65536;
|
||||
cst_size += h_masks_.size()*4;
|
||||
is_mask_cst_ = cst_size < 65536;
|
||||
max_grid_0_ = 256;
|
||||
max_grid_1_ = 256;
|
||||
}
|
||||
|
||||
// comparison for maps
|
||||
std::vector<int64_t> conv::retune_params() const {
|
||||
return {NB_, NC_, AD_, AH_, AW_,
|
||||
NF_, BD_, BH_, BW_,
|
||||
pad_d_, pad_h_, pad_w_,
|
||||
stride_d_, stride_h_, stride_w_,
|
||||
ty_, bias_};
|
||||
}
|
||||
|
||||
// clone
|
||||
base* conv::clone() const {
|
||||
return new conv(*this);
|
||||
}
|
||||
|
||||
size_t conv::a_size()
|
||||
{ return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
|
||||
1, std::multiplies<int>()); }
|
||||
|
||||
size_t conv::b_size()
|
||||
{ return std::accumulate(shapes_b_.begin(), shapes_b_.end(),
|
||||
1, std::multiplies<int>()); }
|
||||
|
||||
size_t conv::c_size()
|
||||
{ return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
|
||||
1, std::multiplies<int>()); }
|
||||
|
||||
std::vector<int32_t> conv::c_shapes()
|
||||
{ return shapes_c_; }
|
||||
|
||||
|
||||
std::tuple<int32_t, int32_t, int32_t, int32_t> conv::unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW) {
|
||||
int32_t l, t, r, s;
|
||||
if(b_trans_){
|
||||
l = ltrs / (EBD*EBH*EBW);
|
||||
int32_t trs = ltrs % (EBD*EBH*EBW);
|
||||
int32_t tr = trs / EBW;
|
||||
s = trs % EBW;
|
||||
t = tr / EBH;
|
||||
r = tr % EBH;
|
||||
}
|
||||
else{
|
||||
int32_t rs = ltrs / NC_;
|
||||
l = ltrs % NC_;
|
||||
r = rs / EBW;
|
||||
s = rs % EBW;
|
||||
}
|
||||
if(flip){
|
||||
r = EBH - 1 - r;
|
||||
s = EBW - 1 - s;
|
||||
}
|
||||
return std::make_tuple(l, t, r, s);
|
||||
}
|
||||
|
||||
void conv::build_b_deltas(){
|
||||
h_b_deltas_.resize(Luts_*upsample_d_*upsample_h_*upsample_w_);
|
||||
|
||||
size_t Ds0 = Luts_;
|
||||
size_t Ds1 = upsample_w_;
|
||||
size_t Ds2 = upsample_h_;
|
||||
size_t Ds3 = upsample_d_;
|
||||
for(size_t ud = 0; ud < Ds3; ++ud)
|
||||
for(size_t uh = 0; uh < Ds2; ++uh)
|
||||
for(size_t uw = 0; uw < Ds1; ++uw) {
|
||||
int32_t* deltas_ptr = &h_b_deltas_[uw*Ds0 + uh*Ds0*Ds1 + ud*Ds0*Ds1*Ds2];
|
||||
for(size_t i = 0; i < Luts_; ++i) {
|
||||
int32_t EBD = 1;
|
||||
int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
|
||||
int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
|
||||
if(EBD == 0 || EBH == 0 || EBW == 0)
|
||||
continue;
|
||||
int32_t c, t, r, s;
|
||||
int32_t nextc, nextt, nextr, nexts;
|
||||
std::tie(c, t, r, s) = unpack(i, false, EBD, EBH, EBW);
|
||||
std::tie(nextc, nextt, nextr, nexts) = unpack(i + TK_, false, EBD, EBH, EBW);
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = (nextt - t)*upsample_d_;
|
||||
int32_t rdiff = (nextr - r)*upsample_h_;
|
||||
int32_t sdiff = (nexts - s)*upsample_w_;
|
||||
deltas_ptr[i] = cdiff*ld_b_[b_inner_idx_] + tdiff*ld_b_[b_pix_idx_] + rdiff*ld_b_[b_pix_idx_ + 1] + sdiff*ld_b_[b_pix_idx_ + 2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void conv::build_a_deltas(){
|
||||
h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
|
||||
size_t Ds0 = Luts_;
|
||||
size_t Ds1 = upsample_w_;
|
||||
size_t Ds2 = upsample_h_;
|
||||
size_t Ds3 = upsample_d_;
|
||||
for(size_t ud = 0; ud < Ds3; ++ud)
|
||||
for(size_t uh = 0; uh < Ds2; ++uh)
|
||||
for(size_t uw = 0; uw < Ds1; ++uw) {
|
||||
int32_t* deltas_ptr = &h_a_deltas_[Luts_ + uw*Ds0 + uh*Ds0*Ds1 + ud*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
int32_t EBD = 1;
|
||||
int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
|
||||
int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
|
||||
if(EBD == 0 || EBH == 0 || EBW == 0)
|
||||
continue;
|
||||
// unpack
|
||||
int32_t ctrs = i;
|
||||
int32_t c, t, r, s;
|
||||
std::tie(c, t, r, s) = unpack(ctrs, !b_trans_, EBD, EBH, EBW);
|
||||
// next indices
|
||||
int32_t nextctrs = ctrs + TK_;
|
||||
int32_t nextc, nextt, nextr, nexts;
|
||||
std::tie(nextc, nextt, nextr, nexts) = unpack(nextctrs, !b_trans_, EBD, EBH, EBW);
|
||||
// diffs
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = nextt - t;
|
||||
int32_t rdiff = nextr - r;
|
||||
int32_t sdiff = nexts - s;
|
||||
if(ty_ == WGRAD){
|
||||
tdiff = tdiff * stride_d_;
|
||||
rdiff = rdiff * stride_h_;
|
||||
sdiff = sdiff * stride_w_;
|
||||
}
|
||||
// delta pointers
|
||||
deltas_ptr[i] = cdiff*ld_a_[a_inner_idx_] + tdiff*ld_a_[a_pix_idx_] + rdiff*ld_a_[a_pix_idx_ + 1] + sdiff*ld_a_[a_pix_idx_ + 2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void conv::build_masks(){
|
||||
h_masks_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*(2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
|
||||
size_t Ms0 = Luts_;
|
||||
size_t Ms1 = 2*pad_w_ + 1;
|
||||
size_t Ms2 = 2*pad_h_ + 1;
|
||||
size_t Ms3 = 2*pad_d_ + 1;
|
||||
size_t Ms4 = upsample_w_;
|
||||
size_t Ms5 = upsample_h_;
|
||||
size_t Ms6 = upsample_d_;
|
||||
for(size_t ud = 0; ud < Ms6; ++ud)
|
||||
for(size_t uh = 0; uh < Ms5; ++uh)
|
||||
for(size_t uw = 0; uw < Ms4; ++uw)
|
||||
for(size_t pd = 0; pd < Ms3; ++pd)
|
||||
for(size_t ph = 0; ph < Ms2; ++ph)
|
||||
for(size_t pw = 0; pw < Ms1; ++pw){
|
||||
int32_t* masks_ptr = &h_masks_[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2 + uw*Ms0*Ms1*Ms2*Ms3 + uh*Ms0*Ms1*Ms2*Ms3*Ms4 + ud*Ms0*Ms1*Ms2*Ms3*Ms4*Ms5];
|
||||
for(size_t i = 0; i < Ms0; ++i){
|
||||
int32_t l, t, r, s;
|
||||
int32_t mask = 0x0;
|
||||
for(size_t j = 0; j < TK_; ++j){
|
||||
int32_t EBD = 1;
|
||||
int32_t EBH = ((upsample_h_ - uh - 1) + BH_) / upsample_h_;
|
||||
int32_t EBW = ((upsample_w_ - uw - 1) + BW_) / upsample_w_;
|
||||
if(EBD == 0 || EBH == 0 || EBW == 0)
|
||||
continue;
|
||||
std::tie(l, t, r, s) = unpack(i + j, !b_trans_, EBD, EBH, EBW);
|
||||
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (EBD + pad_d_);
|
||||
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (EBH + pad_h_);
|
||||
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (EBW + pad_w_);
|
||||
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
|
||||
}
|
||||
masks_ptr[i] = mask;
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
h_masks_[i] = 0x0;
|
||||
}
|
||||
|
||||
std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN){
|
||||
return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
}
|
||||
|
||||
size_t conv::num_flops() const{
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module, triton::runtime::launch_information info) {
|
||||
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
|
||||
if(host.empty())
|
||||
return nullptr;
|
||||
size_t nbytes = host.size()*4;
|
||||
// get buffer
|
||||
triton::driver::buffer* buffer;
|
||||
if(is_cst)
|
||||
buffer = module->symbol(name);
|
||||
else
|
||||
buffer = triton::driver::buffer::create(stream->context(), nbytes);
|
||||
// copy
|
||||
stream->write(buffer, false, 0, nbytes, host.data());
|
||||
return buffer;
|
||||
};
|
||||
if(d_a_deltas_ == nullptr)
|
||||
d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
|
||||
if(d_b_deltas_ == nullptr)
|
||||
d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
|
||||
if(d_masks_ == nullptr)
|
||||
d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
|
||||
if(d_locks_ == nullptr){
|
||||
d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
|
||||
((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
|
||||
}
|
||||
}
|
||||
|
||||
void conv::set_arg(driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias)
|
||||
{
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, bias);
|
||||
kernel->setArg(4, M_);
|
||||
kernel->setArg(5, N_);
|
||||
kernel->setArg(6, K_);
|
||||
kernel->setArg(7, AH_);
|
||||
kernel->setArg(8, AW_);
|
||||
kernel->setArg(9, BH_);
|
||||
kernel->setArg(10, BW_);
|
||||
kernel->setArg(11, CH_);
|
||||
kernel->setArg(12, CW_);
|
||||
kernel->setArg(13, NC_);
|
||||
// A arguments
|
||||
kernel->setArg(14, ld_a_[a_outer_idx_]);
|
||||
kernel->setArg(15, ld_a_[a_inner_idx_]);
|
||||
kernel->setArg(16, ld_a_[2]);
|
||||
kernel->setArg(17, ld_a_[3]);
|
||||
kernel->setArg(18, ld_a_[4]);
|
||||
// B arguments
|
||||
kernel->setArg(19, ld_b_[b_inner_idx_]);
|
||||
kernel->setArg(20, ld_b_[b_pix_idx_]);
|
||||
kernel->setArg(21, ld_b_[b_pix_idx_+1]);
|
||||
kernel->setArg(22, ld_b_[b_pix_idx_+2]);
|
||||
kernel->setArg(23, ld_b_[b_outer_idx_]);
|
||||
// C arguments
|
||||
kernel->setArg(24, ld_c_[c_outer_0_idx_]);
|
||||
kernel->setArg(25, ld_c_[c_outer_1_idx_]);
|
||||
kernel->setArg(26, ld_c_[c_pix_idx]);
|
||||
kernel->setArg(27, ld_c_[c_pix_idx+1]);
|
||||
kernel->setArg(28, ld_c_[c_pix_idx+2]);
|
||||
// pad
|
||||
kernel->setArg(29, pad_h_);
|
||||
kernel->setArg(30, pad_w_);
|
||||
// stride
|
||||
kernel->setArg(31, stride_h_);
|
||||
kernel->setArg(32, stride_w_);
|
||||
// dilate
|
||||
kernel->setArg(33, upsample_h_);
|
||||
kernel->setArg(34, upsample_w_);
|
||||
kernel->setArg(35, (int32_t)0);
|
||||
kernel->setArg(36, (int32_t)0);
|
||||
kernel->setArg(37, pad_h_);
|
||||
kernel->setArg(38, pad_w_);
|
||||
kernel->setArg(39, (int32_t)0);
|
||||
kernel->setArg(40, (int32_t)0);
|
||||
kernel->setArg(41, d_locks_);
|
||||
kernel->setArg(42, max_grid_0_);
|
||||
kernel->setArg(43, max_grid_1_);
|
||||
size_t idx = 44;
|
||||
if(!is_a_deltas_cst)
|
||||
kernel->setArg(idx++, d_a_deltas_);
|
||||
if(!is_b_deltas_cst_)
|
||||
kernel->setArg(idx++, d_b_deltas_);
|
||||
if(!is_mask_cst_)
|
||||
kernel->setArg(idx++, d_masks_);
|
||||
}
|
||||
|
||||
void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
|
||||
unsigned TM = info.globals["TM"], TN = info.globals["TN"];
|
||||
unsigned GZ = 1;
|
||||
set_arg(kernel, a, b, c, bias);
|
||||
std::array<size_t, 3> grid = {1};
|
||||
grid[0] = (M_ + TM - 1)/TM;
|
||||
grid[1] = (N_ + TN - 1)/TN;
|
||||
grid[2] = GZ;
|
||||
grid[0] /= upsample_h_*upsample_w_;
|
||||
kernel->setArg(11, CH_/upsample_h_);
|
||||
kernel->setArg(12, CW_/upsample_w_);
|
||||
|
||||
// initialize to zero if necessary
|
||||
bool init_zero = false;
|
||||
for(int32_t off_uh = 0; off_uh < upsample_h_; off_uh++)
|
||||
for(int32_t off_uw = 0; off_uw < upsample_w_; off_uw++) {
|
||||
int32_t EBD = 1;
|
||||
int32_t EBH = ((upsample_h_ - off_uh - 1) + BH_) / upsample_h_;
|
||||
int32_t EBW = ((upsample_w_ - off_uw - 1) + BW_) / upsample_w_;
|
||||
if(EBD == 0 || EBH == 0 || EBW == 0)
|
||||
init_zero = true;
|
||||
}
|
||||
if(init_zero)
|
||||
((driver::cu_buffer*)c)->set_zero(stream, c_size()*4);
|
||||
|
||||
for(int32_t off_uh = 0; off_uh < upsample_h_; off_uh++)
|
||||
for(int32_t off_uw = 0; off_uw < upsample_w_; off_uw++) {
|
||||
int32_t EBD = 1;
|
||||
int32_t EBH = ((upsample_h_ - off_uh - 1) + BH_) / upsample_h_;
|
||||
int32_t EBW = ((upsample_w_ - off_uw - 1) + BW_) / upsample_w_;
|
||||
if(EBD == 0 || EBH == 0 || EBW == 0)
|
||||
continue;
|
||||
int32_t K = shapes_b_[b_inner_idx_]*EBD*EBH*EBW;
|
||||
kernel->setArg(6, K);
|
||||
kernel->setArg(9, EBH);
|
||||
kernel->setArg(10, EBW);
|
||||
kernel->setArg(29, pad_h_);
|
||||
kernel->setArg(30, pad_w_);
|
||||
kernel->setArg(35, off_uh);
|
||||
kernel->setArg(36, off_uw);
|
||||
kernel->setArg(37, (pad_h_ + (1 - upsample_h_)*off_uh)/upsample_h_);
|
||||
kernel->setArg(38, (pad_w_ + (1 - upsample_w_)*off_uw)/upsample_w_);
|
||||
kernel->setArg(39, (off_uh + pad_h_) % upsample_h_);
|
||||
kernel->setArg(40, (off_uw + pad_w_) % upsample_w_);
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<unsigned> conv::default_params() {
|
||||
if(b_lut_){
|
||||
if(!b_trans_)
|
||||
return {16, 2, 32, 16, 16, 8, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
||||
else
|
||||
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8, 1};
|
||||
}
|
||||
else if(ty_ == FPROP)
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4, 1};
|
||||
else
|
||||
return {16, 2, 64, 16, 16, 16, 4, 2, 2, 4, 2, 8, 4, 2, 1};
|
||||
}
|
||||
|
||||
|
||||
/* CPU reference implementation */
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
IN_DTYPE acc;
|
||||
for(int32_t n = 0; n < shapes_c_[0]; ++n)
|
||||
for(int32_t cf = 0; cf < shapes_c_[1] ; ++cf)
|
||||
for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd)
|
||||
for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch)
|
||||
for(int32_t cw = 0; cw < shapes_c_[4]; ++cw)
|
||||
{
|
||||
acc = 0;
|
||||
int32_t d = cd*stride_d_ - pad_d_;
|
||||
int32_t h = ch*stride_h_ - pad_h_;
|
||||
int32_t w = cw*stride_w_ - pad_w_;
|
||||
for(int32_t ac = 0; ac < shapes_a_[1]; ++ac)
|
||||
for(int32_t bd = 0; bd < shapes_b_[1]; ++bd)
|
||||
for(int32_t bh = 0; bh < shapes_b_[2]; ++bh)
|
||||
for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){
|
||||
int32_t ad = d + bd;
|
||||
int32_t ah = h + bh;
|
||||
int32_t aw = w + bw;
|
||||
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
|
||||
ah >= 0 && ah < shapes_a_[3] &&
|
||||
aw >= 0 && aw < shapes_a_[4]);
|
||||
IN_DTYPE a = 0;
|
||||
if(in_bounds)
|
||||
a = A[n*ld_a_[0] + ac*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
|
||||
IN_DTYPE b;
|
||||
if(b_trans_)
|
||||
b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]];
|
||||
else{
|
||||
int32_t bdd = shapes_b_[1] - 1 - bd;
|
||||
int32_t bhh = shapes_b_[2] - 1 - bh;
|
||||
int32_t bww = shapes_b_[3] - 1 - bw;
|
||||
b = B[cf*ld_b_[0] + bdd*ld_b_[1] + bhh*ld_b_[2] + bww*ld_b_[3] + ac*ld_b_[4]];
|
||||
}
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
C[n*ld_c_[0] + cf*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void conv::cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
IN_DTYPE acc;
|
||||
for(int32_t c = 0 ; c < shapes_c_[0]; ++c)
|
||||
for(int32_t cd = 0; cd < shapes_c_[1]; ++cd)
|
||||
for(int32_t ch = 0; ch < shapes_c_[2]; ++ch)
|
||||
for(int32_t cw = 0; cw < shapes_c_[3]; ++cw)
|
||||
for(int32_t k = 0 ; k < shapes_c_[4]; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
int32_t d = cd*stride_d_ - pad_d_;
|
||||
int32_t h = ch*stride_h_ - pad_h_;
|
||||
int32_t w = cw*stride_w_ - pad_w_;
|
||||
for(int32_t n = 0; n < shapes_b_[0]; ++n)
|
||||
for(int32_t bd = 0; bd < shapes_b_[2]; ++bd)
|
||||
for(int32_t bh = 0; bh < shapes_b_[3]; ++bh)
|
||||
for(int32_t bw = 0; bw < shapes_b_[4]; ++bw){
|
||||
int32_t ad = d + bd;
|
||||
int32_t ah = h + bh;
|
||||
int32_t aw = w + bw;
|
||||
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
|
||||
ah >= 0 && ah < shapes_a_[3] &&
|
||||
aw >= 0 && aw < shapes_a_[4]);
|
||||
IN_DTYPE a = 0;
|
||||
if(in_bounds)
|
||||
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
|
||||
IN_DTYPE b = B[n*ld_b_[0] + k*ld_b_[1] + bd*ld_b_[2] + bh*ld_b_[3] + bw*ld_b_[4]];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
C[c*ld_c_[0] + cd*ld_c_[1] + ch*ld_c_[2] + cw*ld_c_[3] + k*ld_c_[4]] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void conv::cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
if(ty_ == FPROP || ty_ == BPROP)
|
||||
cpu_xprop(C, A, B);
|
||||
else
|
||||
cpu_wgrad(C, A, B);
|
||||
}
|
||||
|
||||
/* Triton-C source code */
|
||||
|
||||
void conv::triton_c_src(std::ostream &os) const {
|
||||
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
|
||||
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
|
||||
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
|
||||
std::string ldb0 = b_trans_ ? "*ldb_s" : "";
|
||||
std::string useb = b_trans_ ? "trans(b)" : "b";
|
||||
std::string flipr = b_trans_ ? "" : "BH - 1 -";
|
||||
std::string flips = b_trans_ ? "" : "BW - 1 -";
|
||||
std::string upar = ty_ == WGRAD ? "stride_h * ": "";
|
||||
std::string upas = ty_ == WGRAD ? "stride_w * ": "";
|
||||
std::string upah = ty_ == WGRAD ? "": "*stride_h";
|
||||
std::string upaw = ty_ == WGRAD ? "": "*stride_w";
|
||||
std::vector<std::string> crs = {"c", "r", "s"};
|
||||
std::vector<std::string> rsc = {"r", "s", "c"};
|
||||
std::vector<std::string> ax = b_trans_ ? crs : rsc;
|
||||
std::vector<std::string> redax;
|
||||
if(b_trans_)
|
||||
redax = {"NC", "BH", "BW"};
|
||||
else
|
||||
redax = {"BH", "BW", "NC"};
|
||||
std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0;
|
||||
std::string inc_pdb = b_trans_ ? "incd" : "TK";
|
||||
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
|
||||
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
|
||||
std::string masks_mem = is_mask_cst_? "__constant__" : "";
|
||||
|
||||
os <<
|
||||
R"(
|
||||
const tunable int TM = {16, 32, 64};
|
||||
const tunable int TN = {16, 32, 64};
|
||||
const tunable int TK = {)" << TK_ << R"(};
|
||||
const tunable int GZ = {1};
|
||||
)";
|
||||
if(is_a_deltas_cst)
|
||||
os << "__constant__ int* delta = alloc_const int[" + std::to_string(h_a_deltas_.size()) + "];\n";
|
||||
if(b_lut_ && is_b_deltas_cst_)
|
||||
os << "__constant__ int* b_delta = alloc_const int[" + std::to_string(h_b_deltas_.size()) + "];\n";
|
||||
if(is_mask_cst_)
|
||||
os << "__constant__ int* masks = alloc_const int[" + std::to_string(h_masks_.size()) + "];\n";
|
||||
os << R"(
|
||||
|
||||
void conv(read_only restrict )" << a_ty_ << R"( *a,
|
||||
read_only restrict )" << b_ty_ << R"( *b,
|
||||
float *c,
|
||||
float *bias,
|
||||
int M, int N, int K,
|
||||
int AH, int AW,
|
||||
int BH, int BW,
|
||||
int CH, int CW,
|
||||
int NC,
|
||||
int lda_n, int lda_c, int lda_d, int lda_h, int lda_w,
|
||||
int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k,
|
||||
int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q,
|
||||
int pad_h, int pad_w,
|
||||
int stride_h, int stride_w,
|
||||
int upsample_h, int upsample_w,
|
||||
int off_uh, int off_uw,
|
||||
int off_uah, int off_uaw,
|
||||
int off_uch, int off_ucw,
|
||||
int *locks, int grid0, int grid1)";
|
||||
if(!is_a_deltas_cst)
|
||||
os << ", int* delta";
|
||||
if(b_lut_ && !is_b_deltas_cst_)
|
||||
os << ", int* b_delta";
|
||||
if(!is_mask_cst_)
|
||||
os << ", int* masks";
|
||||
os << R"(){
|
||||
int rxa[TM] = get_global_range[TM](0);
|
||||
int rb0[TN] = get_global_range[TN](1);
|
||||
int rz = get_global_range[1](2);
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float C[TM, TN] = 0;
|
||||
int ldlut = )" + std::to_string(Luts_) + R"(;
|
||||
int div = K / GZ;
|
||||
int rem = K % GZ;
|
||||
K = select(rz < rem, div, div + rem);
|
||||
int offk = rz*div;
|
||||
rka = rka + offk;
|
||||
rkb = rkb + offk;
|
||||
int rabh[TM] = rxa / CW;
|
||||
int raw[TM] = rxa % CW;
|
||||
int rab[TM] = rabh / CH;
|
||||
int rah[TM] = rabh % CH;
|
||||
rah = rah)" + upaw + R"( - off_uah;
|
||||
raw = raw)" + upah + R"( - off_uaw;
|
||||
int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
||||
int ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
||||
int ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
|
||||
int ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
|
||||
rar = )" + flipr + R"( rar;
|
||||
ras = )" + flips + R"( ras;
|
||||
rar = )" + upar + R"( rar;
|
||||
ras = )" + upas + R"( ras;
|
||||
int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
)" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
|
||||
if(b_lut_){
|
||||
os << R"(
|
||||
int rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
|
||||
int rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
|
||||
int rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
|
||||
int rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
|
||||
rbr = rbr*upsample_h + off_uh;
|
||||
rbs = rbs*upsample_w + off_uw;
|
||||
int offdb[TK] = rkb % ldlut;
|
||||
int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
|
||||
)" + b_delta_mem + R"( int* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
|
||||
int db[TK] = *pdb;)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
int rb1[TK] = rkb)" + ldb0 + ";";
|
||||
}
|
||||
os << R"(
|
||||
)" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
|
||||
int offda[TK] = rka % ldlut;
|
||||
)" + a_delta_mem + R"( int* pincd[TK] = delta + offda;
|
||||
)" + a_delta_mem + R"( int* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
|
||||
int da[TK] = *pda;
|
||||
int incd[TK] = *pincd;
|
||||
int maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
int offma = offk % ldlut;
|
||||
)" + masks_mem + R"( int* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
|
||||
)" + a_delta_mem + R"( int* pincm[TM] = delta + offma;
|
||||
int incm[TM] = *pincm;
|
||||
int maska0[TM] = *pm;
|
||||
int maska1[TK] = 1 << (0 ... TK);
|
||||
bool checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
|
||||
bool checkb0[TN] = rb0 < N;
|
||||
bool checkb)" + BS + " = checkb0" + bcb0 + R"(;
|
||||
)" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
|
||||
)" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
|
||||
int rkamin[TK] = rka - offk + TK;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + da[newaxis, :];
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
pda = pda + incd;)";
|
||||
if(b_lut_){
|
||||
os << R"(
|
||||
pdb = pdb + )" + inc_pdb + R"(;
|
||||
db = *pdb;)";
|
||||
}
|
||||
os << R"(
|
||||
pincd = pincd + incd;
|
||||
da = *pda;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
incm = *pincm;
|
||||
bool checka1[TK] = (rkamin < k);
|
||||
maska0 = *pm;
|
||||
checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
|
||||
checka = checka && checka1[newaxis,:];
|
||||
a = checka ? *pa : 0;
|
||||
checkb = checkb && (k > TK);
|
||||
@checkb b = *pb;
|
||||
}
|
||||
int rxc[TM] = get_global_range[TM](0);
|
||||
int rc1[TN] = get_global_range[TN](1);
|
||||
int rcn[TM] = rxc / (CH*CW);
|
||||
int rcpq[TM] = rxc % (CH*CW);
|
||||
int rcp[TM] = rcpq / CW;
|
||||
int rcq[TM] = rcpq % CW;
|
||||
rcp = rcp * upsample_h + off_uch;
|
||||
rcq = rcq * upsample_w + off_ucw;
|
||||
bool checkc1[TN] = rc1 < N;
|
||||
int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
|
||||
float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1) == 1);
|
||||
int *pcount = plock + grid0*grid1;
|
||||
int count = *pcount;
|
||||
int countp1 = select(count == GZ - 1, 0, count + 1);
|
||||
if(count == 0) {)";
|
||||
if(bias_ && ty_==FPROP){
|
||||
os << R"(
|
||||
float* pbias[TN] = bias + rc1;
|
||||
float bias[TN] = checkc1 ? *pbias : 0;
|
||||
C = C + bias[newaxis, :];)";
|
||||
}
|
||||
os << R"(
|
||||
@checkc *pc = C;
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = C + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
*plock = 0;
|
||||
})";
|
||||
}
|
||||
|
||||
template void conv::cpu_ref<float,float>(float*, float*, float*);
|
||||
template void conv::cpu_xprop<float,float>(float*, float*, float*);
|
||||
template void conv::cpu_wgrad<float,float>(float*, float*, float*);
|
||||
|
||||
}
|
||||
}
|
162
lib/dnn/dot.cpp
162
lib/dnn/dot.cpp
@@ -1,162 +0,0 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/dnn/dot.h"
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
dot::dot(int M, int N, int K,
|
||||
bool AT, bool BT,
|
||||
std::string a_ty, std::string b_ty, std::string c_ty,
|
||||
unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
|
||||
: base("matmul"),
|
||||
M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
|
||||
a_ty_(a_ty), b_ty_(b_ty), c_ty_(c_ty),
|
||||
align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
|
||||
locks_(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
size_t dot::num_flops() const {
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
// retune parameters
|
||||
std::vector<int64_t> dot::retune_params() const {
|
||||
return {M_, N_, K_, AT_, BT_,
|
||||
(int)align_lda_, (int)align_ldb_};
|
||||
}
|
||||
|
||||
// clone
|
||||
base* dot::clone() const {
|
||||
return new dot(*this);
|
||||
}
|
||||
|
||||
void dot::init_impl(driver::stream* stream, driver::cu_module *, runtime::launch_information) {
|
||||
std::vector<int32_t> hlocks(2048, 0);
|
||||
if(locks_ == nullptr)
|
||||
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
|
||||
stream->write(locks_, false, 0, hlocks);
|
||||
}
|
||||
|
||||
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
unsigned TM = info.globals.at("TM");
|
||||
unsigned TN = info.globals.at("TN");
|
||||
unsigned TK = info.globals.at("TK");
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned grid_2 = 1;
|
||||
int32_t lda = AT_ ? K_ : M_;
|
||||
int32_t ldb = BT_ ? N_ : K_;
|
||||
int32_t ldc = M_;
|
||||
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, lda);
|
||||
kernel->setArg(7, ldb);
|
||||
kernel->setArg(8, ldc);
|
||||
kernel->setArg(9, TK);
|
||||
kernel->setArg(10, locks_);
|
||||
kernel->setArg(11, grid_0);
|
||||
kernel->setArg(12, grid_1);
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string ZS = "1";
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
|
||||
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(XAS1, XAS2);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS1, XBS2);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
// std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
// std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN";
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int TM = {128};
|
||||
const tunable int TN = {128};
|
||||
const tunable int TK = {32};
|
||||
|
||||
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty_ + R"( *B,
|
||||
restrict read_only align(16) )" + c_ty_ + R"( *C,
|
||||
int M, int N, int K,
|
||||
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
|
||||
int bound, int *locks, int grid0, int grid1) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty_ + R"( c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
os << res;
|
||||
}
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
std::vector<params_t> dot::search_space() const {
|
||||
return dot_search_space(AT_, BT_);
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
params_t dot::heuristics() const {
|
||||
return dot_heuristics(AT_, BT_, M_, N_, K_);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,538 +0,0 @@
|
||||
#include <sstream>
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
|
||||
shift::shift(int B, int C,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S,
|
||||
int F,
|
||||
int stride_h, int stride_w,
|
||||
const int32_t *shift_h, const int32_t *shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
op_t ty, bool bias,
|
||||
layout_t layout)
|
||||
: base("shift"),
|
||||
B_(B), C_(C),
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
F_(F),
|
||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty), c_ty_(b_ty),
|
||||
op_(ty), bias_(bias),
|
||||
layout_(layout){
|
||||
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||
// max number of channels
|
||||
TK_ = (ty == FPROP && a_ty_ == "float") ? 8 : 32;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
// activation sizes
|
||||
CD_ = AD_ / stride_d_;
|
||||
CH_ = AH_ / stride_h_;
|
||||
CW_ = AW_ / stride_w_;
|
||||
// A memory strides: [C, H, W, B]
|
||||
switch(layout_){
|
||||
case CHWN: {
|
||||
lda_n_ = 1;
|
||||
lda_w_ = B_;
|
||||
lda_h_ = B_*AW_;
|
||||
lda_c_ = B_*AW_*AH_;
|
||||
break;
|
||||
}
|
||||
case NCHW: {
|
||||
lda_w_ = 1;
|
||||
lda_h_ = AW_;
|
||||
lda_c_ = AW_*AH_;
|
||||
lda_n_ = AW_*AH_*C_;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
// Shift edge
|
||||
shift_edge_h_ = (AH_ == stride_h_ && stride_h_ > 1);
|
||||
shift_edge_w_ = (AW_ == stride_w_ && stride_w_ > 1);
|
||||
// B memory strides: [C, F]
|
||||
ldb_n_ = 1;
|
||||
ldb_h_ = 1;
|
||||
ldb_w_ = 1;
|
||||
ldb_c_ = F_;
|
||||
// C memory strides: [F, H, W, B]
|
||||
switch(layout_){
|
||||
case CHWN: {
|
||||
ldc_n_ = 1;
|
||||
ldc_w_ = B_;
|
||||
ldc_h_ = B_*CW_;
|
||||
ldc_f_ = B_*CW_*CH_;
|
||||
break;
|
||||
}
|
||||
case NCHW: {
|
||||
ldc_w_ = 1;
|
||||
ldc_h_ = CW_;
|
||||
ldc_f_ = CW_*CH_;
|
||||
ldc_n_ = CW_*CH_*F_;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
IAD_ = AD_ - 2*(BD_/2);
|
||||
IAH_ = AH_ - 2*(BH_/2);
|
||||
IAW_ = AW_ - 2*(BW_/2);
|
||||
ICD_ = IAD_ / stride_d_;
|
||||
ICH_ = IAH_ / stride_h_;
|
||||
ICW_ = IAW_ / stride_w_;
|
||||
|
||||
// Equivalent matmul
|
||||
M_ = B_*ICH_*ICW_;
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// transpose
|
||||
AT_ = false;
|
||||
BT_ = true;
|
||||
// C shapes
|
||||
if(layout_ == CHWN)
|
||||
shapes_c_ = {F, CH_, CW_, B};
|
||||
if(layout_ == NCHW)
|
||||
shapes_c_ = {B, F, CH_, CW_};
|
||||
// Weight gradient
|
||||
if(op_ == WGRAD){
|
||||
// b <-> c
|
||||
// b <-> a
|
||||
std::swap(ldb_n_, ldc_n_);
|
||||
std::swap(ldb_w_, ldc_w_);
|
||||
std::swap(ldb_h_, ldc_h_);
|
||||
std::swap(ldb_c_, ldc_f_);
|
||||
std::swap(lda_n_, ldb_n_);
|
||||
std::swap(lda_w_, ldb_w_);
|
||||
std::swap(lda_h_, ldb_h_);
|
||||
std::swap(lda_c_, ldb_c_);
|
||||
std::swap(M_, K_);
|
||||
std::swap(M_, N_);
|
||||
AT_ = true;
|
||||
BT_ = false;
|
||||
shapes_c_ = {C, F};
|
||||
}
|
||||
// Input gradient
|
||||
if(op_ == BPROP){
|
||||
// a <-> c
|
||||
std::swap(lda_n_, ldc_n_);
|
||||
std::swap(lda_w_, ldc_w_);
|
||||
std::swap(lda_h_, ldc_h_);
|
||||
std::swap(lda_c_, ldc_f_);
|
||||
std::swap(K_, N_);
|
||||
AT_ = false;
|
||||
BT_ = false;
|
||||
if(layout_ == CHWN)
|
||||
shapes_c_ = {C, AH_, AW_, B};
|
||||
if(layout_ == NCHW)
|
||||
shapes_c_ = {B, C, AH_, AW_};
|
||||
}
|
||||
// locks
|
||||
max_locks_ = (op_ == WGRAD) ? 8192 : 0;
|
||||
locks_ = nullptr;
|
||||
}
|
||||
|
||||
base* shift::clone() const {
|
||||
return new shift(*this);
|
||||
}
|
||||
|
||||
void shift::build_delta_a() {
|
||||
h_delta_a.resize(MAX_C_);
|
||||
auto shift_h = [&](int c) { return shift_edge_h_ ? (c / AH_) % AH_ : shift_h_[c]; };
|
||||
auto shift_w = [&](int c) { return shift_edge_w_ ? c % AW_ : shift_w_[c]; };
|
||||
if(op_ == FPROP){
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*lda_c_ + shift_h(c)*lda_h_ + shift_w(c)*lda_w_;
|
||||
};
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_delta_a[c] = offset(c);
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_delta_a[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
for(unsigned c = 0; c < C_; c++){
|
||||
h_delta_a[c] = shift_h(c)*ldc_h_ + shift_w(c)*ldc_w_;
|
||||
}
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_delta_a[c] = shift_h(c)*ldb_h_ + shift_w(c)*ldb_w_;
|
||||
}
|
||||
}
|
||||
|
||||
size_t shift::c_size() {
|
||||
return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
std::vector<int32_t> shift::c_shapes(){
|
||||
return shapes_c_;
|
||||
}
|
||||
|
||||
size_t shift::num_flops() const {
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
bool shift::AT() const
|
||||
{ return AT_; }
|
||||
|
||||
bool shift::BT() const
|
||||
{ return BT_; }
|
||||
|
||||
size_t shift::M() const
|
||||
{ return M_; }
|
||||
|
||||
size_t shift::N() const
|
||||
{ return N_; }
|
||||
|
||||
size_t shift::K() const
|
||||
{ return K_; }
|
||||
|
||||
size_t shift::lda() const
|
||||
{ return AT_ ? K_ : M_; }
|
||||
|
||||
size_t shift::ldb() const
|
||||
{ return BT_ ? N_ : K_; }
|
||||
|
||||
size_t shift::ldc() const
|
||||
{ return M_; }
|
||||
|
||||
std::vector<int64_t> shift::retune_params() const {
|
||||
return {B_, C_, F_,
|
||||
AD_, AH_, AW_,
|
||||
BD_, BH_, BW_,
|
||||
CD_, CH_, CW_,
|
||||
(int64_t)shift_h_, (int64_t)shift_w_,
|
||||
stride_h_, stride_w_,
|
||||
layout_, op_,
|
||||
bias_};
|
||||
}
|
||||
|
||||
// One-time device-side initialization for this op:
//  - uploads the shift look-up table into the module's __constant__ memory;
//  - lazily allocates and zero-fills the inter-CTA lock buffer used by the
//    split-K reduction (WGRAD only, since max_locks_ is 0 otherwise).
void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
  build_delta_a();
  // `module` is already a cu_module* — the original redundantly cast it.
  triton::driver::buffer* delta_a = module->symbol("delta_a");
  // async copy of the int32 table; sizeof(int32_t) replaces the magic `4`
  stream->write(delta_a, false, 0, h_delta_a.size()*sizeof(int32_t), h_delta_a.data());
  // locks: two int32 slots per lock (spin-lock word + completion counter —
  // presumably; confirm against the kernel's reduction code)
  if(locks_ == nullptr && max_locks_ > 0){
    std::vector<int32_t> hlocks(2*max_locks_, 0);
    locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*sizeof(int32_t));
    stream->write(locks_, false, 0, hlocks);
  }
}
|
||||
|
||||
void shift::deinit_impl() {
|
||||
if(locks_ != nullptr){
|
||||
delete locks_;
|
||||
locks_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Bind all kernel arguments and launch the shift kernel on `stream`.
// `args` must contain exactly {A, B, C} device buffers, in that order.
// The argument indices below are positional and must match the parameter
// list emitted by triton_c_src().
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                         std::vector<driver::buffer *> args,
                         runtime::launch_information info) {
  // tile sizes chosen by the tuner
  unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN");
  // 2D grid over the output tiles (ceil-div)
  unsigned grid_0 = (M_ + TM - 1)/TM;
  unsigned grid_1 = (N_ + TN - 1)/TN;
  unsigned num_locks = grid_0 * grid_1;
  // split-K (third grid dimension) only when one lock per output tile fits
  // in the pre-allocated lock buffer.
  // NOTE(review): this uses `<` while the locks argument below uses `>`,
  // so num_locks == max_locks_ disables split-K but still passes the lock
  // buffer — presumably harmless, but confirm the boundary is intended.
  unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
  std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
  driver::buffer *a = args[0], *b = args[1], *c = args[2];
//  std::cout << op_ << " " << M_ << " " << N_ << " " << K_ << std::endl;
  // operand buffers
  kernel->setArg(0, a);
  kernel->setArg(1, b);
  kernel->setArg(2, c);
  // GEMM sizes
  kernel->setArg(3, M_);
  kernel->setArg(4, N_);
  kernel->setArg(5, K_);
  // spatial strides
  kernel->setArg(6, stride_h_);
  kernel->setArg(7, stride_w_);
  // leading dimensions of A
  kernel->setArg(8, lda_n_);
  kernel->setArg(9, lda_w_);
  kernel->setArg(10, lda_h_);
  kernel->setArg(11, lda_c_);
  // leading dimensions of B
  kernel->setArg(12, ldb_n_);
  kernel->setArg(13, ldb_w_);
  kernel->setArg(14, ldb_h_);
  kernel->setArg(15, ldb_c_);
  // leading dimensions of C
  kernel->setArg(16, ldc_n_);
  kernel->setArg(17, ldc_w_);
  kernel->setArg(18, ldc_h_);
  kernel->setArg(19, ldc_f_);
  // batch size and spatial extents of A, B (filter) and C
  kernel->setArg(20, B_);
  kernel->setArg(21, IAH_);
  kernel->setArg(22, IAW_);
  kernel->setArg(23, BH_);
  kernel->setArg(24, BW_);
  kernel->setArg(25, ICH_);
  kernel->setArg(26, ICW_);
  // lock buffer (null when the grid needs more locks than were allocated)
  kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
  // grid shape, passed so the kernel can compute its lock index
  kernel->setArg(28, (int32_t)grid[0]);
  kernel->setArg(29, (int32_t)grid[1]);
  kernel->setArg(30, (int32_t)grid[2]);
  // locks must be re-zeroed before every launch: the kernel leaves them set
  if(locks_)
    ((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}
|
||||
|
||||
// Emit the Triton-C source of the shift kernel for the current operation
// (FPROP / BPROP / WGRAD) and layout (CHWN / NCHW) into `os`. The kernel is
// a GEMM whose A (or B/C, depending on the op) pointers are displaced
// per-channel through the __constant__ `delta_a` table built by
// build_delta_a().
void shift::triton_c_src(std::ostream &os) const {
  // Tile-shape strings for A and B, swapped below when the operand is
  // transposed; bca*/bcb* are the broadcasting suffixes used to build the
  // bounds-check predicates.
  std::string AS0 = "TM", AS1 = "TK";
  std::string BS0 = "TK", BS1 = "TN";
  std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
  std::string usea = AT_ ? "trans(a)" : "a";
  std::string useb = BT_ ? "trans(b)" : "b";
  std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
  std::string stride_h = std::to_string(stride_h_);
  std::string stride_w = std::to_string(stride_w_);
  if(AT_){
    std::swap(AS0, AS1);
    std::swap(bca0, bca1);
  }
  if(BT_){
    std::swap(BS0, BS1);
    std::swap(bcb0, bcb1);
  }
  std::string AS = AS0 + ", " + AS1;
  std::string BS = BS0 + ", " + BS1;
  bool is_chwn = layout_ == CHWN;

  // In CHWN layout the batch stride is 1, so it is folded into the source
  // as a literal instead of a kernel parameter.
  std::string lda_b = is_chwn ? "1" : "lda_b";
  std::string ldb_b = is_chwn ? "1" : "ldb_b";
  std::string ldc_b = is_chwn ? "1" : "ldc_b";

  // Emit code decomposing a flat index `rkx` (of tile size `sz`) into
  // batch/height/width coordinates named `rx`b / `rx`h / `rx`w, with the
  // decomposition order depending on the memory layout.
  auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
    std::string B = std::to_string(B_);
    std::string CW = std::to_string(ICW_);
    std::string CH = std::to_string(ICH_);

    if(is_chwn) {
      return R"(
  int )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
  int )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
  int )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
  int )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
    }
    else {
      return R"(
  int )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
  int )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
  int )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
  int )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
    }
  };

  // Tunable parameters exposed to the auto-tuner.
  std::string result =
  R"(
  const tunable int TM = {16, 32, 64, 128};
  const tunable int TN = {16, 32, 64, 128};
  const tunable int TK = {)" + std::to_string(TK_) + "};";
  // NOTE(review): both branches emit the same GZ range, so the conditional
  // is currently dead — presumably WGRAD was meant to expose GZ > 1 for
  // split-K; confirm before removing.
  if(op_ == WGRAD)
    result += "const tunable int GZ = {1};";
  else
    result += "const tunable int GZ = {1};";

  // Kernel prototype; the parameter order here must match the positional
  // setArg() calls in enqueue_impl().
  result += R"(
  __constant__ int* delta_a = alloc_const int[)" + std::to_string(MAX_C_) + R"(];

  void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
             restrict read_only align(16) )" + b_ty_ + R"( *B,
             )" + c_ty_ + R"( *C,
             int M, int N, int K,
             int stride_h, int stride_w,
             multiple_of(8) int lda_b, multiple_of(8) int lda_w, multiple_of(8) int lda_h, multiple_of(8) int lda_c,
             multiple_of(8) int ldb_b, multiple_of(8) int ldb_w, multiple_of(8) int ldb_h, multiple_of(8) int ldb_c,
             multiple_of(8) int ldc_b, multiple_of(8) int ldc_w, multiple_of(8) int ldc_h, multiple_of(8) int ldc_c,
             int NB,
             int AH, int AW,
             int BH, int BW,
             int CH, int CW,
             int* locks, int grid0, int grid1, int grid2) {
  int ridx = get_program_id(0);
  int ridy = get_program_id(1);
  int rz = get_program_id(2);
  int rxa[TM] = ridx*TM + (0 ... TM);
  int ryb[TN] = ridy*TN + (0 ... TN);
  int rka[TK] = 0 ... TK;
  int rkb[TK] = 0 ... TK;
  float acc[TM, TN] = 0;
  int pad_h = BH / 2;
  int pad_w = BW / 2;)";

  /* A offsets */
  if(op_ == FPROP){
    result +=
    compute_bhw("ra", "TM", "rxa") + R"(
  raw = raw * )" + stride_w + R"(;
  rah = rah * )" + stride_h + R"(;
  int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
  int offa0[TM, TK] = offxa[:, newaxis];
  __constant__ int* pd[TK] = delta_a + rka;
  multiple_of(8) int d[TK] = *pd;
  int offa1[TM, TK] = d[newaxis, :];)";
  }
  if(op_ == BPROP){
    result +=
    compute_bhw("ra", "TM", "rxa") + R"(
  int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
  int offa0[TM, TK] = offxa[:, newaxis];
  int offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
  }
  if(op_ == WGRAD){
    result +=
    compute_bhw("ra", "TK", "rka") + R"(
  int offa0[TK, TM] = rxa[newaxis, :] * lda_c;
  int offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
  int offa1[TK, TM] = offxa[:, newaxis];)";
  }

  /* B offsets */
  if(op_ == FPROP){
    result += R"(
  int offb0[TN, TK] = ryb[:, newaxis];
  int offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
  }
  if(op_ == BPROP){
    result += R"(
  int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
  int offb1[TK, TN] = rkb[:, newaxis];)";
  }
  // WGRAD keeps explicit base pointers (pa_base/pb_base) because its
  // pointers are recomputed from scratch at every K-iteration.
  if(op_ == WGRAD){
    result +=
    compute_bhw("rb", "TK", "rkb") + R"(
  __constant__ int* pd[TN] = delta_a + ryb;
  multiple_of(8) int d[TN] = *pd;
  multiple_of(8) int shift[TK, TN] = d[newaxis, :];
  rbw = rbw * )" + stride_w + R"(;
  rbh = rbh * )" + stride_h + R"(;
  int offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
  int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
  int offb1[TK, TN] = offkb[:, newaxis];
  )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
  )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
  )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
  )" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)";
  }
  else{
    result += R"(
  )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
  )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)";
  }

  /* Main loop: software-pipelined GEMM — loads for iteration k+1 are
     issued after the dot() of iteration k. */
  result += R"(
  bool checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
  bool checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
  )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
  )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
  for(int k = K; k > 0; k = k - TK){
    acc = dot()" + usea + "," + useb + R"(, acc);
    bool checka[)" + AS + R"(] = k > TK;
    bool checkb[)" + BS + R"(] = k > TK;)";

  /* Increment A pointers */
  if(op_ == FPROP){
    // advance through the delta table: d holds per-channel increments
    result += R"(
    pd = pd + TK;
    d = *pd;
    pa = pa + d[newaxis, :];)";
  }
  if(op_ == BPROP){
    result += R"(
    pa = pa + TK * lda_c;)";
  }
  if(op_ == WGRAD){
    result += R"(
    rka = rka + TK;)"
    + compute_bhw("ra", "TK", "rka") + R"(
    offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
    pa = pa_base + offxa[:, newaxis];)";
  }
  result += R"(
    a = checka ? *pa : 0;)";

  /* Increment B pointers */
  if(op_ == WGRAD){
    result += R"(
    rkb = rkb + TK;)"
    + compute_bhw("rb", "TK", "rkb") + R"(
    rbw = rbw * )" + stride_w + R"(;
    rbh = rbh * )" + stride_h + R"(;
    offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
    pb = pb_base + offkb[:, newaxis];)";
  }
  if(op_ == FPROP){
    result += R"(
    pb = pb + TK * ldb_c;)";
  }
  if(op_ == BPROP){
    result += R"(
    pb = pb + TK;)";
  }
  result += R"(
    b = checkb ? *pb : 0;
  }
  int rxc[TM] = ridx*TM + (0 ... TM);
  int ryc[TN] = ridy*TN + (0 ... TN);)";

  /* C offsets */
  if(op_ == BPROP){
    result +=
    compute_bhw("rc", "TM", "rxc") + R"(
  rcw = rcw * )" + stride_w + R"(;
  rch = rch * )" + stride_h + R"(;
  int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
  }
  if(op_ == FPROP){
    result +=
    compute_bhw("rc", "TM", "rxc") + R"(
  int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
  }
  if(op_ == WGRAD){
    result += R"(
  int offxc[TM] = rxc;)";
  }
  // NOTE(review): the raw string below begins with a stray `"` which is
  // emitted into the generated Triton-C source — this looks like a bug
  // (or an artifact of how this file was captured); verify against the
  // original repository before relying on it.
  result += R"("
  )" + c_ty_ + R"( c[TM, TN] = acc;
  )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
  bool checkc0[TM] = rxc < M;
  bool checkc1[TN] = ryc < N;
  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
  // BPROP scatters through the shift table (the store is displaced per
  // output channel); other ops store straight to pc.
  if(op_ == BPROP){
    result += R"(
  __constant__ int* pd[TN] = delta_a + ryc;
  )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
  @checkc *shift_pc = c;
  )";
  }
  else{
    result += R"(
  @checkc *pc = c;)";
  }
  result += R"(
  })";

  os << result;
}
|
||||
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
std::vector<params_t> shift::search_space() const {
  // Delegate to the shared GEMM search space: only the transposition modes
  // of the two operands affect which tile configurations are legal.
  return dot_search_space(AT_, BT_);
}
|
||||
|
||||
// simple parameter heuristics
|
||||
params_t shift::heuristics() const {
  // Delegate to the shared GEMM heuristic, parameterized by the operand
  // transposition modes and the problem sizes.
  return dot_heuristics(AT_, BT_, M_, N_, K_);
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,284 +0,0 @@
|
||||
#include <string>
|
||||
#include "triton/lang/lang.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopPass.h"
|
||||
#include "triton/tools/thread_pool.h"
|
||||
#include <mutex>
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
extern triton::lang::translation_unit *ast_root;
|
||||
|
||||
namespace triton {
|
||||
namespace runtime{
|
||||
|
||||
// Invoke `f` once for every point of the D-dimensional index space
// [0, ranges[0]) x ... x [0, ranges[D-1]), in lexicographic order with the
// last dimension varying fastest.
//
// `nthreads` is currently unused: the thread-pool dispatch is commented out
// and every call runs sequentially on the caller's thread.
//
// Fixes over the original:
//  - D == 0 made `i = D - 1` wrap around (out-of-bounds indexing);
//    a zero-dimensional nest now calls f exactly once with an empty tuple.
//  - a zero-length range made `ranges[i] - 1` wrap to SIZE_MAX, so f was
//    called for a point outside the (empty) space and the loop effectively
//    never terminated; an empty space now performs no calls at all.
void parallel_loop_nest(std::vector<size_t> const & ranges,
                        std::function<void(std::vector<size_t> const &)> const & f,
                        size_t nthreads){
  size_t D = ranges.size();
  // zero-dimensional space: a single (empty) point
  if(D == 0){
    f({});
    return;
  }
  // any empty dimension makes the whole space empty
  for(size_t r: ranges)
    if(r == 0)
      return;
  std::vector<size_t> values(D, 0);
  // Odometer-style iteration, starting with the innermost dimension.
  size_t i = D - 1;
  while(true){
    // Execute function (sequentially; see note on nthreads above)
    f(values);
    // increment innermost index, carrying into outer dimensions on overflow
    while(values[i]++ == ranges[i] - 1){
      if(i == 0)
        return;
      values[i--] = 0;
    }
    i = D - 1;
    // Short sleep so that a future thread pool doesn't grow too big
    std::this_thread::sleep_for(std::chrono::microseconds(1));
  }
}
|
||||
|
||||
// Invoke `f` on every element of the Cartesian product of the `iterates`
// value lists (one list per dimension), by mapping index tuples produced by
// the size_t overload back to value tuples.
template<class T>
void parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
  // Extent of each dimension of the product space.
  std::vector<size_t> extents;
  extents.reserve(iterates.size());
  for(const auto& domain: iterates)
    extents.push_back(domain.size());
  // Translate an index tuple into the corresponding value tuple, then call f.
  auto call_with_values = [&](std::vector<size_t> const & idx){
    std::vector<T> values(iterates.size());
    for(size_t d = 0; d < values.size(); ++d)
      values[d] = iterates[d][idx[d]];
    f(values);
  };
  // Walk the underlying index space.
  parallel_loop_nest(extents, call_with_values, nthreads);
}
|
||||
|
||||
void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std::function<void(std::vector<unsigned>)> const & f, size_t nthreads) {
|
||||
ThreadPool pool(nthreads);
|
||||
for(const std::vector<unsigned>& values: iterates)
|
||||
pool.enqueue(f, values);
|
||||
}
|
||||
|
||||
|
||||
// Lower a Triton-IR module to LLVM-IR via the selection pass, and record the
// launch metadata (tunable globals and thread count) into `info`.
// Uses std::make_unique instead of the original raw `new` for exception
// safety; the caller receives ownership of the LLVM module.
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
  auto result = std::make_unique<llvm::Module>(module.get_name(), llvm_context);
  passes.selection.run(module, *result);
  // add globals: export each compile-time metaparameter value by name
  for(auto x: module.globals())
    info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value();
  // number of threads per workgroup chosen by the tuner
  info.num_threads = passes.tune.get_num_threads();
  return result;
}
|
||||
|
||||
// Parse Triton-C source text into an AST.
// Uses the global flex/bison parser state (yy_scan_string / yyparse /
// ast_root), so this function is NOT thread-safe; `name` is currently
// unused. The returned AST is owned by the parser's global state.
triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) {
  // create AST from Triton-C source
  YY_BUFFER_STATE buffer = yy_scan_string(src);
  yyparse();
  yy_delete_buffer(buffer);
  triton::lang::translation_unit *program = ast_root;
  return program;
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module(name, context);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
|
||||
// Construct a JIT bound to `context`; the codegen target is derived from the
// context's device, and `nthreads` caps the workers used by autotune().
jit::jit(driver::context *context, unsigned nthreads): driver_context_(context),
                                                       target_(context->device()->make_target()),
                                                       nthreads_(nthreads) { }

// Out-of-line empty destructor — presumably anchors destruction of members
// whose complete types are only visible in this translation unit.
jit::~jit(){ }
|
||||
|
||||
// Return the first metaparameter assignment that satisfies all tuner
// constraints for the given kernel source, without benchmarking anything.
// Throws std::runtime_error if no point of the search space is valid.
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
  // find metaparameters
  triton::lang::translation_unit* program = parse_program(name, src);
  auto ptt_module = make_triton_module(name, triton_context_, program);
  ir::module &tt_module = *ptt_module;
  // set parameters
  passes_wrapper passes(target_.get());
  passes.target_independent(tt_module);
  passes.tune.run(tt_module);
  auto mps = passes.tune.get_params(tt_module);
  // create parameter ranges (one candidate list per metaparameter)
  std::vector<std::vector<unsigned>> ranges;
  for(ir::metaparameter *mp: mps)
    ranges.push_back(mp->get_space());
  // iterate over parameters; the callback mutates the shared tuner state,
  // so the nest is deliberately run with a single thread (last argument 1)
  std::vector<unsigned> result;
  parallel_loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
    // stop probing once a valid assignment has been found
    if(!result.empty())
      return;
    std::map<ir::value*, std::vector<std::string>> errors;
    unsigned i = 0;
    for(ir::metaparameter *mp: mps)
      mp->set_value(params[i++]);
    passes.tune.init(tt_module);
    passes.tune.check_constraints(errors);
    if(!errors.empty())
      return;
    result = params;
  }, 1);
  if(result.empty())
    throw std::runtime_error("couldn't find valid parameters");
  return result;
}
|
||||
|
||||
|
||||
|
||||
// Exhaustively (or over the explicit `targets` list) evaluate metaparameter
// assignments for the kernel `src`: for each valid assignment, compile the
// kernel and measure it with `benchmark`, keeping the best-performing
// parameters. Throws std::runtime_error if no assignment is valid.
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> & targets) {
  // find metaparameters
  triton::lang::translation_unit* program = parse_program(name, src);
  auto ptt_module_0 = make_triton_module(name, triton_context_, program);
  ir::module &tt_module_0 = *ptt_module_0;
  // set parameters
  passes_wrapper passes_0(target_.get());
  passes_0.target_independent(tt_module_0);
  passes_0.tune.run(tt_module_0);
  auto mps = passes_0.tune.get_params(tt_module_0);
  // iterate over parameters
  tune_res_t best;
  // update_best: evaluate one parameter assignment, possibly concurrently
  // with other evaluations. `mutex` guards both the shared tt_module_0 used
  // for the cheap constraint pre-check and the best/benchmark section.
  std::mutex mutex;
  auto update_best = [&](const std::vector<unsigned> params){
    std::map<ir::value*, std::vector<std::string>> errors;
    unsigned i = 0;
    // Phase 1 (locked): cheap feasibility check against the shared module.
    {
      std::lock_guard<std::mutex> lock(mutex);
      for(ir::metaparameter *mp: mps)
        mp->set_value(params[i++]);
//      for(size_t i = 0; i < params.size(); i++)
//        std::cout << ((i==0)?"":", ") << params[i] << std::flush;
//      std::cout << std::endl;
      passes_0.tune.init(tt_module_0);
      passes_0.tune.check_constraints(errors);
//      for(auto x: errors)
//      for(auto e: x.second){
//        std::cout << x.first->get_name() << ": " << e << std::endl;
//      }
    }
    if(!errors.empty())
      return;
    // Phase 2 (unlocked): deep copy of the module and tuner, so compilation
    // of this candidate does not race with other candidates.
    triton::ir::context triton_context;
    auto ptt_module_1 = make_triton_module(name, triton_context, program);
    ir::module &tt_module_1 = *ptt_module_1;
    // run passes
    passes_wrapper passes_1(target_.get());
    passes_1.target_independent(tt_module_1);
    passes_1.tune.run(tt_module_1);
    i = 0;
    for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){
      mp->set_value(params[i++]);
    }
    passes_1.tune.init(tt_module_1);
    passes_1.target_dependent(tt_module_1);
    // reject candidates that exceed the device's hardware limits
    driver::device* device = driver_context_->device();
    if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory())
      return;
    if(passes_1.tune.get_num_threads() > device->max_threads_per_block())
      return;
    // Compile
    launch_information info;
    llvm::LLVMContext llvm_context;
    auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
    std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
    double perf;
    // Phase 3 (locked): benchmark serially and update the running best.
    {
      std::lock_guard<std::mutex> lock(mutex);
      std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
      perf = benchmark(kernel.get(), info);
      if(perf > best.perf){
        best.perf = perf;
        best.params = params;
      }
      // progress report: candidate params, its perf, and the best so far
      for(size_t i = 0; i < params.size(); i++)
        std::cout << ((i==0)?"":", ") << params[i] << std::flush;
      std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
    }
  };

  // Full sweep of the search space, or only the explicitly supplied targets.
  if(targets.empty()) {
    // create parameter ranges
    std::vector<std::vector<unsigned>> ranges;
    for(ir::metaparameter *mp: mps)
      ranges.push_back(mp->get_space());
    parallel_loop_nest<unsigned>(ranges, update_best, nthreads_);
  }
  else {
    parallel_for_each(targets, update_best, nthreads_);
  }

  if(best.params.empty())
    throw std::runtime_error("auto-tuning didn't find valid parameters");
//  std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
  return best;
}
|
||||
|
||||
// Compile a Triton-IR module with a fixed metaparameter assignment `params`
// (e.g. the result of autotune() or get_valid()) and register the resulting
// machine-code module under the module's name for later get_function() /
// get_launch_info() lookups. Throws std::runtime_error if `params` violates
// the tuner's constraints.
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
  // set parameters
  passes_wrapper passes(target_.get());
  passes.target_independent(tt_module);
  passes.tune.run(tt_module);
  unsigned i = 0;
  // `params` must be ordered exactly like tune.get_params() returns them
  for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
    mp->set_value(params[i++]);
  passes.tune.init(tt_module);
  passes.target_dependent(tt_module);
  // check constraints; report every violation before bailing out
  std::map<ir::value*, std::vector<std::string>> errors;
  passes.tune.check_constraints(errors);
  for(auto x: errors){
    for(auto str: x.second)
      std::cout << x.first->get_name() << ": " << str << std::endl;
  }
  if(errors.size())
    throw std::runtime_error("invalid parameters");
  // triton module -> llvm module (also records launch metadata by name)
  std::string name = tt_module.get_name();
  auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
  // llvm module -> machine code
  modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
}
|
||||
|
||||
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||
triton::lang::translation_unit* program = parse_program(name, src);
|
||||
auto ptt_module = make_triton_module(name, triton_context_, program);
|
||||
add_module(*ptt_module, params);
|
||||
}
|
||||
|
||||
// Instantiate a kernel handle from the module previously registered under
// `name`. Throws std::out_of_range if add_module() was never called for
// that name. The raw pointer is presumably owned by the caller (autotune
// wraps the same factory in a unique_ptr) — confirm against kernel::create.
driver::kernel *jit::get_function(const char *name) {
  return driver::kernel::create(modules_.at(name), name);
}
|
||||
|
||||
// Copy out the launch metadata (tunable globals, thread count) recorded when
// the module `name` was compiled. Throws std::out_of_range for unknown names.
launch_information jit::get_launch_info(const char *name) {
  const launch_information& info = launch_info_map_.at(name);
  return info;
}
|
||||
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user