[RUNTIME] Added auto-alignment mechanism (#71)

This PR adds an automatic memory alignment mechanism in the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power of two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on-the-fly for all auto-tunable kernels. There is a cache that remembers all the kernels compiled for each possible configuration.

This PR also includes substantial cleaning of the Python API. This adds 2-3us overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this is preferred for now.
This commit is contained in:
Philippe Tillet
2021-03-04 01:51:11 -05:00
committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
19 changed files with 668 additions and 707 deletions

View File

@@ -10,6 +10,7 @@
#include <functional>
#include "triton/ir/builder.h"
#include "triton/ir/metadata.h"
#include "triton/ir/context.h"
namespace triton{
@@ -60,7 +61,7 @@ private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, context &ctx);
module(const std::string &name);
context& get_context();
builder& get_builder();
// Setters
@@ -94,7 +95,7 @@ public:
private:
std::string name_;
context &context_;
context context_;
builder builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;

View File

@@ -5,6 +5,7 @@
#include <string>
#include <stdexcept>
#include <sstream>
namespace triton{
namespace ir{
@@ -17,73 +18,8 @@ namespace driver{
namespace runtime {
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
arg_type convert(ir::type *ty);
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T: return 1;
case INT8_T: return 1;
case INT16_T: return 2;
case INT32_T: return 4;
case INT64_T: return 8;
case HALF_T: return 2;
case FLOAT_T: return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
inline bool is_int_type(arg_type ty){
return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
ty == INT32_T || ty == INT64_T;
}
class arg {
public:
union value_t {
bool int1;
int8_t int8;
int16_t int16;
int32_t int32;
int64_t int64;
uint16_t fp16;
float fp32;
double fp64;
driver::buffer* buf;
};
public:
// construct from primitive types
arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
// accessors
arg_type type() const { return ty_; }
void* data() const { return (void*)&val_; }
driver::buffer* buffer() const { return val_.buf; }
private:
arg_type ty_;
value_t val_;
};
}
}

View File

@@ -11,6 +11,7 @@
#include <memory>
#include <functional>
// codegen
#include "triton/ir/function.h"
#include "triton/ir/context.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/error.h"
@@ -37,63 +38,86 @@ class context;
namespace triton{
namespace runtime{
typedef std::vector<size_t> grid_t;
typedef std::map<std::string, size_t> params_t;
template<typename T> inline T convert(const std::string& name);
template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
/* ------------------------- */
/* Compilation options */
/* ------------------------- */
struct options_t {
template<class T>
T D(const std::string& name) const {
return std::stoi(defines.at(name));
}
std::unordered_map<std::string, std::string> defines;
int num_warps;
};
/* ------------------------- */
/* Runtime arguments */
/* ------------------------- */
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T : return 1;
case INT8_T : return 1;
case INT16_T : return 2;
case INT32_T : return 4;
case INT64_T : return 8;
case HALF_T : return 2;
case FLOAT_T : return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
template<class T>
void add_arg(std::stringstream& ss, T arg) {
ss.write((char*)&arg, sizeof(T));
}
/* ------------------------- */
/* ------------------------- */
enum asm_mode_t {
ASM_LLIR,
ASM_NV_PTX,
ASM_NV_SASS
};
struct options_t {
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
std::unordered_map<std::string, std::string> defines;
int num_warps;
};
/* ------------------------- */
class kernel{
private:
static std::string preheader();
static arg_type convert(ir::type *ty);
public:
typedef std::vector<size_t> grid_t;
public:
kernel(const std::string& src, const options_t& opt, driver::device *device);
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
// getters
const std::vector<arg_type>& get_sig() const { return sig_; }
const std::vector<std::string>& get_arg_names() const { return arg_names_; }
std::string get_asm(asm_mode_t mode);
static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
static std::tuple<std::shared_ptr<driver::module>,
std::shared_ptr<driver::kernel>,
size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);
private:
void init_ir (const std::string &src);
void init_ker();
void init_sig();
public:
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
std::string get_asm(asm_mode_t mode);
public:
const options_t opt;
private:
driver::device* dev_;
// signature
std::vector<arg_type> sig_;
std::vector<std::string> arg_names_;
// triton context for parsing
ir::context ctx_;
// handles
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
@@ -102,36 +126,37 @@ private:
size_t shared_mem_;
};
struct config {
std::map<std::string, std::string> defines;
int num_warps;
};
class function {
public:
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;
typedef std::vector<config> autotune_confs_t;
private:
static void do_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f);
public:
function(const std::string& src, const options_t& opt, driver::device *device,
const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
// auto-tuning
cache_t::iterator find_in_cache(void* args, size_t args_size);
kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
// getters
const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
const std::vector<arg_type> get_signature() { return sig_; }
private:
void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);
private:
std::vector<kernel_pair_t> kernels_;
std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
std::map<std::vector<uint64_t>, kernel*> cache_;
std::vector<arg_type> sig_;
std::vector<int> align_idxs_;
std::vector<int> int_idxs_;
std::vector<int> key_idxs_;
std::vector<int> arg_size_;
std::vector<int> arg_off_;
std::vector<options_t> opts_;
std::string src_;
driver::device* device_;
};
}