[RUNTIME] Added auto-alignment mechanism (#71)
This PR adds an automatic memory-alignment mechanism to the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument, as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. The corresponding .aligned and .multipleof attributes are then added to the Triton-IR on the fly for all auto-tunable kernels. A cache remembers every kernel compiled for each possible configuration. This PR also includes a substantial cleanup of the Python API. The new mechanism adds 2-3 µs of overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so the new one is preferred for now.
This commit is contained in:
committed by
Philippe Tillet
parent
ff62f7fffc
commit
62835a0979
@@ -10,8 +10,8 @@ namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* Module */
|
||||
module::module(const std::string &name, context &ctx)
|
||||
: name_(name), context_(ctx), builder_(ctx) {
|
||||
module::module(const std::string &name)
|
||||
: name_(name), builder_(context_) {
|
||||
sealed_blocks_.insert(nullptr);
|
||||
}
|
||||
|
||||
|
@@ -40,7 +40,6 @@
|
||||
#include <mutex>
|
||||
#include <fstream>
|
||||
|
||||
std::mutex mut;
|
||||
|
||||
namespace triton{
|
||||
namespace runtime {
|
||||
@@ -49,22 +48,9 @@ namespace runtime {
|
||||
/* --------------------------------- */
|
||||
/* --------------------------------- */
|
||||
|
||||
/* Maps a Triton-IR type onto the runtime's argument-type tag.
 *
 * Integer types are told apart by bit width (1, 8, 16, 32, 64), then the
 * half / float / double predicates are tried, and finally pointers, which map
 * to BUFFER_T. Throws std::runtime_error("unknown type") when no predicate
 * matches.
 */
arg_type kernel::convert(ir::type *ty) {
  // integer widths, probed in ascending order
  struct int_case { int bits; arg_type tag; };
  const int_case int_cases[] = {
    {1, INT1_T}, {8, INT8_T}, {16, INT16_T}, {32, INT32_T}, {64, INT64_T}
  };
  for(const int_case& c: int_cases){
    if(ty->is_integer_ty(c.bits))
      return c.tag;
  }
  // floating-point types
  if(ty->is_half_ty()){
    return HALF_T;
  }
  if(ty->is_float_ty()){
    return FLOAT_T;
  }
  if(ty->is_double_ty()){
    return DOUBLE_T;
  }
  // pointers
  if(ty->is_pointer_ty()){
    return BUFFER_T;
  }
  throw std::runtime_error("unknown type");
}
|
||||
|
||||
|
||||
std::string kernel::preheader() {
|
||||
return R"(
|
||||
std::shared_ptr<ir::module> kernel::src_to_ir(const std::string& _src, const options_t& opt) {
|
||||
std::string src =
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
@@ -116,9 +102,7 @@ typedef short int16;
|
||||
typedef int int32;
|
||||
typedef long int64;
|
||||
)";
|
||||
}
|
||||
|
||||
void kernel::init_ir(const std::string& src) {
|
||||
src += _src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src, true);
|
||||
@@ -129,21 +113,21 @@ void kernel::init_ir(const std::string& src) {
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
// ast -> triton-ir
|
||||
ir::module* module = new ir::module("", ctx_);
|
||||
auto ret = std::make_shared<ir::module>("");
|
||||
Generator gen(&parser);
|
||||
gen.Gen(module);
|
||||
ir_.reset(module);
|
||||
gen.Gen(&*ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void kernel::init_ker(){
|
||||
// triton-ir -> binary
|
||||
std::unique_ptr<driver::module> bin;
|
||||
std::unique_ptr<codegen::target> target = dev_->make_target();
|
||||
std::tuple<std::shared_ptr<driver::module>,
|
||||
std::shared_ptr<driver::kernel>,
|
||||
size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) {
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::string name = ir_->get_function_list()[0]->get_name();
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
// optimizations
|
||||
std::unique_ptr<codegen::target> target = dev->make_target();
|
||||
bool cts_use_async = target->as_nvidia()->sm() >= 80;
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
@@ -162,73 +146,61 @@ void kernel::init_ker(){
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
dce.run(*ir_);
|
||||
pipeline.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
disassociate.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
align.run(*ir_);
|
||||
axes.run(*ir_);
|
||||
layouts.run(*ir_);
|
||||
peephole.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
dce.run(ir);
|
||||
pipeline.run(ir);
|
||||
dce.run(ir);
|
||||
disassociate.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
if(target->is_gpu())
|
||||
cts.run(*ir_);
|
||||
align.run(*ir_);
|
||||
axes.run(*ir_);
|
||||
layouts.run(*ir_);
|
||||
coalesce.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
align.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
cts.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
coalesce.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if(target->is_gpu()){
|
||||
reassociate.run(*ir_);
|
||||
cts.run(*ir_);
|
||||
reassociate.run(ir);
|
||||
cts.run(ir);
|
||||
}
|
||||
dce.run(*ir_);
|
||||
// ir::print(*ir_, std::cout);
|
||||
align.run(*ir_);
|
||||
axes.run(*ir_);
|
||||
layouts.run(*ir_);
|
||||
peephole.run(*ir_);
|
||||
dce.run(*ir_);
|
||||
align.run(*ir_);
|
||||
axes.run(*ir_);
|
||||
layouts.run(*ir_);
|
||||
swizzle.run(*ir_);
|
||||
liveness.run(*ir_);
|
||||
allocation.run(*ir_);
|
||||
shared_mem_ = allocation.allocated_size();
|
||||
// if(allocation.allocated_size() > dev_->max_shared_memory())
|
||||
// throw exception::out_of_shared_memory();
|
||||
barriers.run(*ir_);
|
||||
isel.visit(*ir_, *llvm);
|
||||
//if(res->spilled() > 256)
|
||||
// throw exception::out_of_registers();
|
||||
mod_.reset(driver::module::create(dev_, std::move(llvm)));
|
||||
ker_.reset(driver::kernel::create(&*mod_, name.c_str()));
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
swizzle.run(ir);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
barriers.run(ir);
|
||||
isel.visit(ir, *llvm);
|
||||
std::shared_ptr<driver::module> mod(driver::module::create(dev, std::move(llvm)));
|
||||
std::shared_ptr<driver::kernel> ker(driver::kernel::create(&*mod, name.c_str()));
|
||||
size_t shared_mem = allocation.allocated_size();
|
||||
return std::make_tuple(mod, ker, shared_mem);
|
||||
}
|
||||
|
||||
void kernel::init_sig() {
|
||||
ir::function* fn = ir_->get_function_list()[0];
|
||||
ir::function_type* ty = fn->get_fn_type();
|
||||
for(size_t i = 0; i < ty->get_num_params(); i++){
|
||||
sig_.push_back(convert(ty->get_param_ty(i)));
|
||||
if(!fn->has_attr(i+1))
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev):
|
||||
kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map<int, ir::attribute> &attrs):
|
||||
opt(opt), dev_(dev) {
|
||||
init_ir(preheader() + src);
|
||||
init_ker();
|
||||
init_sig();
|
||||
for(auto arg: ir_->get_function_list()[0]->args())
|
||||
arg_names_.push_back(arg->get_name());
|
||||
// compile to Triton IR
|
||||
ir_ = src_to_ir(src, opt);
|
||||
// add attributes
|
||||
for(const auto&x: attrs)
|
||||
ir_->get_function_list()[0]->add_attr(x.first, x.second);
|
||||
// compile to binary
|
||||
std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt);
|
||||
}
|
||||
|
||||
void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
|
||||
void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector<size_t>& _grid) const{
|
||||
// set grid
|
||||
if(_grid.size() > 3)
|
||||
throw std::runtime_error("grid size must be no greater than 3");
|
||||
@@ -236,7 +208,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
|
||||
for(size_t i = 0; i < 3; i++)
|
||||
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
|
||||
// enqueue
|
||||
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
|
||||
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
|
||||
}
|
||||
|
||||
std::string kernel::get_asm(asm_mode_t mode) {
|
||||
@@ -282,124 +254,124 @@ std::string kernel::get_asm(asm_mode_t mode) {
|
||||
/* --------------------------------- */
|
||||
/* --------------------------------- */
|
||||
|
||||
/* Calls `f` once for every index tuple of a rectangular loop nest.
 *
 * `ranges[d]` is the extent of dimension d; the tuple handed to `f` has one
 * entry per dimension, each in [0, ranges[d]). Tuples are visited with the
 * last dimension varying fastest.
 *
 * Edge cases:
 *  - if any extent is zero the iteration space is empty and `f` is never called;
 *  - if `ranges` is empty the nest has exactly one (empty) tuple, so `f` is
 *    called once with an empty vector.
 */
void function::do_loop_nest(std::vector<size_t> const & ranges,
                            std::function<void(std::vector<size_t> const &)> const & f){
  // A zero extent would otherwise make `extent - 1` wrap around.
  for(size_t extent: ranges){
    if(extent == 0)
      return;
  }
  const size_t D = ranges.size();
  std::vector<size_t> values(D, 0);
  while(true){
    f(values);
    // Odometer-style increment: bump the innermost index and carry into the
    // next outer dimension each time an index reaches its extent.
    size_t i = D;
    while(i > 0){
      if(++values[i - 1] < ranges[i - 1])
        break;
      values[i - 1] = 0;
      --i;
    }
    // The carry ran past the outermost dimension: every tuple has been visited.
    if(i == 0)
      return;
  }
}
|
||||
|
||||
|
||||
void function::init_kernels(const std::string& src, const options_t& opt,
|
||||
const autotune_vals_t& confs, driver::device *device) {
|
||||
// list of all possible configs
|
||||
// just augment `opt` with each define of `confs`
|
||||
// and override warp count
|
||||
size_t num_opts = std::max(confs.size(), (size_t)1);
|
||||
std::vector<options_t> opts(num_opts, opt);
|
||||
for(size_t i = 0; i < confs.size(); i++){
|
||||
opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
|
||||
opts[i].num_warps = confs[i].second;
|
||||
}
|
||||
// compile all possible configs
|
||||
// compilation errors (e.g., too much shared mem)
|
||||
// will populate `err`
|
||||
std::vector<std::pair<options_t, std::string>> err;
|
||||
for(const options_t& opt: opts) {
|
||||
try{
|
||||
kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
|
||||
}catch(const exception::base& e){
|
||||
err.push_back({opt, e.what()});
|
||||
}
|
||||
}
|
||||
// throw an exception if `err` is not empty
|
||||
if(kernels_.empty()){
|
||||
std::ostringstream dbg;
|
||||
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
|
||||
for(auto x: err){
|
||||
dbg << "[ ";
|
||||
dbg << x.first.num_warps << ", ";
|
||||
dbg << "{ ";
|
||||
for(const auto& y: x.first.defines)
|
||||
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
|
||||
dbg << " } ] -> " << x.second << std::endl;
|
||||
}
|
||||
throw exception::no_valid_configuration(dbg.str());
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns the kernel to launch for the given packed argument buffer.
 *
 * When only one kernel was compiled it is returned directly. Otherwise a key
 * is built from the raw bytes of the arguments listed in `key_idxs_`; on a
 * cache miss every kernel in `kernels_` is timed with tools::bench and the one
 * with the smallest time is stored in `cache_` under that key.
 */
kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) {
  // nothing to choose from
  if(kernels_.size() == 1)
    return &*kernels_.begin()->second;
  // key: value bytes of each auto-tuning argument (unused upper bytes stay zero)
  char* raw = (char*)args;
  std::vector<uint64_t> key(key_idxs_.size());
  for(size_t k = 0; k < key.size(); k++){
    int arg = key_idxs_[k];
    std::memcpy((void*)&key[k], (void*)(raw + arg_off_[arg]), arg_size_[arg]);
  }
  // already tuned for this key?
  auto cached = cache_.find(key);
  if(cached != cache_.end())
    return cached->second;
  // benchmark every candidate and keep the fastest
  double best_ts = INFINITY;
  kernel* best = nullptr;
  for(auto &entry : kernels_){
    kernel* candidate = &*entry.second;
    auto grid = grid_fn(entry.first);
    // pad the launch grid to three dimensions
    for(size_t d = grid.size(); d < 3; d++)
      grid.push_back(1);
    double ts = tools::bench([&]() { (*candidate)(args, args_size, stream, grid); },
                             stream, 5, 20);
    if(ts < best_ts)
      best = candidate;
    best_ts = std::min(ts, best_ts);
  }
  stream->synchronize();
  return cache_.insert({key, best}).first->second;
}
|
||||
|
||||
function::function(const std::string& src, const options_t &opt, driver::device *device,
|
||||
const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
|
||||
// pre-compile all kernels
|
||||
init_kernels(src, opt, autotune_vals, device);
|
||||
// find indices of autotune keys
|
||||
auto arg_names = kernels_.at(0).second->get_arg_names();
|
||||
for(const std::string& name: autotune_key){
|
||||
auto it = std::find(arg_names.begin(), arg_names.end(), name);
|
||||
if(it == arg_names.end())
|
||||
throw std::runtime_error(name + " is not a valid argument name");
|
||||
key_idxs_.push_back(std::distance(arg_names.begin(), it));
|
||||
const std::vector<config> &tune_confs, const std::vector<std::string>& tune_key)
|
||||
: src_(src), device_(device) {
|
||||
// kernel options
|
||||
size_t num_opts = std::max(tune_confs.size(), (size_t)1);
|
||||
opts_ = std::vector<options_t>(num_opts, opt);
|
||||
for(size_t i = 0; i < tune_confs.size(); i++){
|
||||
opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end());
|
||||
opts_[i].num_warps = tune_confs[i].num_warps;
|
||||
}
|
||||
std::shared_ptr<ir::module> ir = kernel::src_to_ir(src, opts_[0]);
|
||||
std::vector<ir::argument*> args = ir->get_function_list()[0]->args();
|
||||
// signature
|
||||
auto convert = [](ir::type *ty) {
|
||||
if(ty->is_integer_ty(1)) return INT1_T;
|
||||
if(ty->is_integer_ty(8)) return INT8_T;
|
||||
if(ty->is_integer_ty(16)) return INT16_T;
|
||||
if(ty->is_integer_ty(32)) return INT32_T;
|
||||
if(ty->is_integer_ty(64)) return INT64_T;
|
||||
if(ty->is_half_ty()) return HALF_T;
|
||||
if(ty->is_float_ty()) return FLOAT_T;
|
||||
if(ty->is_double_ty()) return DOUBLE_T;
|
||||
if(ty->is_pointer_ty()) return BUFFER_T;
|
||||
throw std::runtime_error("unknown type");
|
||||
};
|
||||
for(ir::argument* arg: args)
|
||||
sig_.push_back(convert(arg->get_type()));
|
||||
// find indices of autotune keys
|
||||
for(const std::string& name: tune_key){
|
||||
auto pred = [&](ir::argument* arg) { return arg->get_name() == name; };
|
||||
auto it = std::find_if(args.begin(), args.end(), pred);
|
||||
if(it == args.end())
|
||||
throw std::runtime_error(name + " is not a valid argument name");
|
||||
key_idxs_.push_back(std::distance(args.begin(), it));
|
||||
}
|
||||
// find indices of pointer
|
||||
for(size_t i = 0; i < args.size(); i++)
|
||||
if(args[i]->get_type()->is_pointer_ty() ||
|
||||
args[i]->get_type()->is_integer_ty())
|
||||
align_idxs_.push_back(i);
|
||||
// argument size and offset
|
||||
auto tys = kernels_.at(0).second->get_sig();
|
||||
size_t curr = 0;
|
||||
for(arg_type ty: tys){
|
||||
for(arg_type ty: sig_){
|
||||
arg_size_.push_back(size_of(ty));
|
||||
arg_off_.push_back(curr);
|
||||
curr += arg_size_.back();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
runtime::kernel* fn = autotune(args, args_size, grid_fn, stream);
|
||||
(*fn)(args, args_size, stream, grid_fn(fn->opt));
|
||||
/* Returns the largest of 16, 8, 4, 2, 1 that divides N evenly.
 * Zero is divisible by every candidate, so pow2_divisor(0) is 16.
 */
uint64_t pow2_divisor(uint64_t N){
  // probe candidates from the largest down; 1 divides everything
  for(uint64_t candidate = 16; candidate > 1; candidate /= 2){
    if(N % candidate == 0)
      return candidate;
  }
  return 1;
}
|
||||
|
||||
void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) {
|
||||
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
|
||||
/* Returns the kernel to launch for the packed argument buffer `args`,
 * compiling and benchmarking on demand.
 *
 * Two keys are derived from `args`:
 *  - the alignment key holds pow2_divisor() of every argument listed in
 *    `align_idxs_`;
 *  - the auto-tuning key holds the raw value bytes of every argument listed in
 *    `key_idxs_`.
 * Their concatenation indexes `cache_`. On a cache miss, the kernels for the
 * alignment key are compiled if `kernels_` does not have them yet (one per
 * entry of `opts_`, each given the matching aligned / multiple_of attributes);
 * if there is more than one they are timed with tools::bench and the fastest
 * is cached.
 *
 * Any exception thrown while compiling propagates to the caller; in that case
 * no entry is left behind in `kernels_` for this alignment key, so a later
 * call compiles again instead of picking from an incomplete list.
 *
 * NOTE(review): copying an argument narrower than 8 bytes into the first bytes
 * of a zeroed uint64_t reproduces its value only on little-endian hosts.
 */
kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) {
  // memcpy only reads from the buffer, so no const-stripping cast is needed
  const char* raw = args.data();
  // align key
  std::vector<uint64_t> rt_key(align_idxs_.size(), 0);
  for(size_t i = 0; i < align_idxs_.size(); i++){
    int idx = align_idxs_[i];
    uint64_t tmp = 0;
    std::memcpy(&tmp, raw + arg_off_[idx], arg_size_[idx]);
    rt_key[i] = pow2_divisor(tmp);
  }
  // auto-tuning key
  std::vector<uint64_t> at_key(key_idxs_.size(), 0);
  for(size_t i = 0; i < at_key.size(); i++){
    int idx = key_idxs_[i];
    std::memcpy(&at_key[i], raw + arg_off_[idx], arg_size_[idx]);
  }
  // cache key
  std::vector<uint64_t> cache_key;
  cache_key.reserve(rt_key.size() + at_key.size());
  cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end());
  cache_key.insert(cache_key.end(), at_key.begin(), at_key.end());
  auto it = cache_.find(cache_key);
  if(it != cache_.end())
    return it->second;
  // compile kernels
  if(kernels_.find(rt_key) == kernels_.end()){
    std::map<int, ir::attribute> attrs;
    for(size_t i = 0; i < align_idxs_.size(); i++){
      bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T;
      // attribute slots are the argument index plus one
      // (presumably slot 0 is reserved for something else -- TODO confirm)
      attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])});
    }
    auto& compiled = kernels_[rt_key];
    try{
      for(const options_t& opt: opts_)
        compiled.emplace_back(new kernel(src_, opt, device_, attrs));
    }catch(...){
      // drop the partially-filled entry so it is never mistaken for a
      // complete set of compiled kernels
      kernels_.erase(rt_key);
      throw;
    }
  }
  // run auto-tuner
  double best_ts = INFINITY;
  auto& kernels = kernels_.at(rt_key);
  kernel* ret = nullptr;
  if(kernels.size() == 1)
    ret = &*kernels.back();
  else{
    for(auto &current : kernels){
      auto grid = grid_fn(current->opt);
      // pad the launch grid to three dimensions
      while(grid.size() < 3)
        grid.push_back(1);
      double ts = tools::bench([&]() { (*current)(args, stream, grid); },
                               stream, 5, 20);
      ret = (ts < best_ts) ? &*current : ret;
      best_ts = std::min(ts, best_ts);
    }
    stream->synchronize();
  }
  it = cache_.insert({cache_key, ret}).first;
  return it->second;
}
|
||||
|
||||
/* Launches the function on `stream`: picks a kernel for `args` through
 * autotune(), then enqueues it with the grid that `grid_fn` yields for that
 * kernel's options.
 */
void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
  runtime::kernel& selected = *autotune(args, grid_fn, stream);
  selected(args, stream, grid_fn(selected.opt));
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user