[RUNTIME] Added auto-alignment mechanism (#71)

This PR adds an automatic memory-alignment mechanism to the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument, as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on the fly for all auto-tunable kernels, and a cache remembers the kernels compiled for each possible configuration.
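For illustration, here is a minimal, self-contained sketch of the detection logic. It mirrors the pow2_divisor helper added later in this diff; the main driver and the sample values are hypothetical, and the real runtime reads the raw argument bytes out of the launch buffer rather than from local variables:

#include <cstdint>
#include <iostream>

// largest power-of-two divisor of N, capped at 16
// (same logic as the pow2_divisor helper introduced in this diff)
uint64_t pow2_divisor(uint64_t N){
  if(N % 16 == 0) return 16;
  if(N % 8 == 0) return 8;
  if(N % 4 == 0) return 4;
  if(N % 2 == 0) return 2;
  return 1;
}

int main(){
  uint64_t ptr = 0x7f3a00000080; // hypothetical pointer argument, 128-byte aligned
  uint64_t N   = 1024;           // hypothetical integer argument
  // the pointer would be tagged .aligned(16), the integer .multipleof(16)
  std::cout << ".aligned(" << pow2_divisor(ptr) << ") "
            << ".multipleof(" << pow2_divisor(N) << ")\n";
}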

This PR also includes a substantial cleanup of the Python API. It adds 2-3us of overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this trade-off is preferred for now.
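As a rough sketch of where that overhead comes from (assumptions: options_t is reduced here to the defines and num_warps fields that appear in this diff, and get_int_define is a hypothetical helper, not part of the runtime), each launch that needs an integer #define must fetch it from a string map and parse it:

#include <map>
#include <string>

// simplified stand-in for the runtime's options_t
struct options_t {
  std::map<std::string, std::string> defines; // e.g. {"TM": "128"}
  int num_warps;
};

// looking up an integer #define at launch time means a string-keyed map
// lookup plus an integer parse on every call -- cheap, but it is where
// the 2-3us mentioned above goes
int get_int_define(const options_t& opt, const std::string& name){
  return std::stoi(opt.defines.at(name));
}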
Authored by Philippe Tillet on 2021-03-04 01:51:11 -05:00, committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
19 changed files with 668 additions and 707 deletions

lib/ir/module.cc

@@ -10,8 +10,8 @@ namespace triton{
 namespace ir{
 /* Module */
-module::module(const std::string &name, context &ctx)
-  : name_(name), context_(ctx), builder_(ctx) {
+module::module(const std::string &name)
+  : name_(name), builder_(context_) {
   sealed_blocks_.insert(nullptr);
 }

lib/runtime/function.cc

@@ -40,7 +40,6 @@
 #include <mutex>
 #include <fstream>
-std::mutex mut;
 namespace triton{
 namespace runtime {
@@ -49,22 +48,9 @@ namespace runtime {
 /* --------------------------------- */
 /* --------------------------------- */
-arg_type kernel::convert(ir::type *ty) {
-  if(ty->is_integer_ty(1))  return INT1_T;
-  if(ty->is_integer_ty(8))  return INT8_T;
-  if(ty->is_integer_ty(16)) return INT16_T;
-  if(ty->is_integer_ty(32)) return INT32_T;
-  if(ty->is_integer_ty(64)) return INT64_T;
-  if(ty->is_half_ty())      return HALF_T;
-  if(ty->is_float_ty())     return FLOAT_T;
-  if(ty->is_double_ty())    return DOUBLE_T;
-  if(ty->is_pointer_ty())   return BUFFER_T;
-  throw std::runtime_error("unknown type");
-}
-std::string kernel::preheader() {
-  return R"(
+std::shared_ptr<ir::module> kernel::src_to_ir(const std::string& _src, const options_t& opt) {
+  std::string src =
+      R"(
 #define bool _Bool
 #define true 1
 #define false 0
@@ -116,9 +102,7 @@ typedef short int16;
 typedef int int32;
 typedef long int64;
 )";
-}
-void kernel::init_ir(const std::string& src) {
+  src += _src;
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&src, true);
@@ -129,21 +113,21 @@ void kernel::init_ir(const std::string& src) {
   Parser parser(tokens);
   parser.Parse();
   // ast -> triton-ir
-  ir::module* module = new ir::module("", ctx_);
+  auto ret = std::make_shared<ir::module>("");
   Generator gen(&parser);
-  gen.Gen(module);
-  ir_.reset(module);
+  gen.Gen(&*ret);
+  return ret;
 }
-void kernel::init_ker(){
-  // triton-ir -> binary
-  std::unique_ptr<driver::module> bin;
-  std::unique_ptr<codegen::target> target = dev_->make_target();
+std::tuple<std::shared_ptr<driver::module>,
+           std::shared_ptr<driver::kernel>,
+           size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) {
   // generate llvm code
   llvm::LLVMContext ctx;
-  std::string name = ir_->get_function_list()[0]->get_name();
+  std::string name = ir.get_function_list()[0]->get_name();
   std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
   // optimizations
+  std::unique_ptr<codegen::target> target = dev->make_target();
   bool cts_use_async = target->as_nvidia()->sm() >= 80;
   // create passes
   codegen::analysis::align align;
@@ -162,73 +146,61 @@ void kernel::init_ker(){
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
   // run passes
-  dce.run(*ir_);
-  pipeline.run(*ir_);
-  dce.run(*ir_);
-  disassociate.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  peephole.run(*ir_);
-  dce.run(*ir_);
+  dce.run(ir);
+  pipeline.run(ir);
+  dce.run(ir);
+  disassociate.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  peephole.run(ir);
+  dce.run(ir);
   if(target->is_gpu())
-    cts.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  coalesce.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  dce.run(*ir_);
+    cts.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  coalesce.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  dce.run(ir);
   if(target->is_gpu()){
-    reassociate.run(*ir_);
-    cts.run(*ir_);
+    reassociate.run(ir);
+    cts.run(ir);
   }
-  dce.run(*ir_);
-  // ir::print(*ir_, std::cout);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  peephole.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  swizzle.run(*ir_);
-  liveness.run(*ir_);
-  allocation.run(*ir_);
-  shared_mem_ = allocation.allocated_size();
-  // if(allocation.allocated_size() > dev_->max_shared_memory())
-  //   throw exception::out_of_shared_memory();
-  barriers.run(*ir_);
-  isel.visit(*ir_, *llvm);
-  //if(res->spilled() > 256)
-  //  throw exception::out_of_registers();
-  mod_.reset(driver::module::create(dev_, std::move(llvm)));
-  ker_.reset(driver::kernel::create(&*mod_, name.c_str()));
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  peephole.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  swizzle.run(ir);
+  liveness.run(ir);
+  allocation.run(ir);
+  barriers.run(ir);
+  isel.visit(ir, *llvm);
+  std::shared_ptr<driver::module> mod(driver::module::create(dev, std::move(llvm)));
+  std::shared_ptr<driver::kernel> ker(driver::kernel::create(&*mod, name.c_str()));
+  size_t shared_mem = allocation.allocated_size();
+  return std::make_tuple(mod, ker, shared_mem);
 }
-void kernel::init_sig() {
-  ir::function* fn = ir_->get_function_list()[0];
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++){
-    sig_.push_back(convert(ty->get_param_ty(i)));
-    if(!fn->has_attr(i+1))
-      continue;
-  }
-}
-kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev):
+kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map<int, ir::attribute> &attrs):
   opt(opt), dev_(dev) {
-  init_ir(preheader() + src);
-  init_ker();
-  init_sig();
-  for(auto arg: ir_->get_function_list()[0]->args())
-    arg_names_.push_back(arg->get_name());
+  // compile to Triton IR
+  ir_ = src_to_ir(src, opt);
+  // add attributes
+  for(const auto&x: attrs)
+    ir_->get_function_list()[0]->add_attr(x.first, x.second);
+  // compile to binary
+  std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt);
 }
-void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
+void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector<size_t>& _grid) const{
   // set grid
   if(_grid.size() > 3)
     throw std::runtime_error("grid size must be no greater than 3");
@@ -236,7 +208,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
 }
 std::string kernel::get_asm(asm_mode_t mode) {
@@ -282,124 +254,124 @@ std::string kernel::get_asm(asm_mode_t mode) {
 /* --------------------------------- */
 /* --------------------------------- */
-void function::do_loop_nest(std::vector<size_t> const & ranges,
-                            std::function<void(std::vector<size_t> const &)> const & f){
-  size_t D = ranges.size();
-  std::vector<size_t> values(D, 0);
-  size_t i = D - 1;
-  while(true){
-    f(values);
-    while(values[i]++ == ranges[i] - 1){
-      if(i == 0)
-        return;
-      values[i--] = 0;
-    }
-    i = D - 1;
-  }
-}
-void function::init_kernels(const std::string& src, const options_t& opt,
-                            const autotune_vals_t& confs, driver::device *device) {
-  // list of all possible configs
-  // just augment `opt` with each define of `confs`
-  // and override warp count
-  size_t num_opts = std::max(confs.size(), (size_t)1);
-  std::vector<options_t> opts(num_opts, opt);
-  for(size_t i = 0; i < confs.size(); i++){
-    opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
-    opts[i].num_warps = confs[i].second;
-  }
-  // compile all possible configs
-  // compilation errors (e.g., too much shared mem)
-  // will populate `err`
-  std::vector<std::pair<options_t, std::string>> err;
-  for(const options_t& opt: opts) {
-    try{
-      kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
-    }catch(const exception::base& e){
-      err.push_back({opt, e.what()});
-    }
-  }
-  // throw an exception if `err` is not empty
-  if(kernels_.empty()){
-    std::ostringstream dbg;
-    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
-    for(auto x: err){
-      dbg << "[ ";
-      dbg << x.first.num_warps << ", ";
-      dbg << "{ ";
-      for(const auto& y: x.first.defines)
-        dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
-      dbg << " } ] -> " << x.second << std::endl;
-    }
-    throw exception::no_valid_configuration(dbg.str());
-  }
-}
-kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) {
-  // fast path -- no autotuning necessary
-  if(kernels_.size() == 1)
-    return &*kernels_.begin()->second;
-  // auto-tuning key
-  std::vector<uint64_t> key(key_idxs_.size());
-  for(size_t i = 0; i < key.size(); i++){
-    int idx = key_idxs_[i];
-    std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
-  }
-  auto it = cache_.find(key);
-  if(it != cache_.end())
-    return it->second;
-  // run auto-tuner
-  double best_ts = INFINITY;
-  kernel* ret = nullptr;
-  for(auto &x : kernels_){
-    kernel* current = &*x.second;
-    auto grid = grid_fn(x.first);
-    while(grid.size() < 3)
-      grid.push_back(1);
-    double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
-                             stream, 5, 20);
-    ret = (ts < best_ts) ? current : ret;
-    best_ts = std::min(ts, best_ts);
-  }
-  stream->synchronize();
-  it = cache_.insert({key, ret}).first;
-  return it->second;
-}
 function::function(const std::string& src, const options_t &opt, driver::device *device,
-                   const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
-  // pre-compile all kernels
-  init_kernels(src, opt, autotune_vals, device);
-  // find indices of autotune keys
-  auto arg_names = kernels_.at(0).second->get_arg_names();
-  for(const std::string& name: autotune_key){
-    auto it = std::find(arg_names.begin(), arg_names.end(), name);
-    if(it == arg_names.end())
-      throw std::runtime_error(name + " is not a valid argument name");
-    key_idxs_.push_back(std::distance(arg_names.begin(), it));
+                   const std::vector<config> &tune_confs, const std::vector<std::string>& tune_key)
+  : src_(src), device_(device) {
+  // kernel options
+  size_t num_opts = std::max(tune_confs.size(), (size_t)1);
+  opts_ = std::vector<options_t>(num_opts, opt);
+  for(size_t i = 0; i < tune_confs.size(); i++){
+    opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end());
+    opts_[i].num_warps = tune_confs[i].num_warps;
+  }
+  std::shared_ptr<ir::module> ir = kernel::src_to_ir(src, opts_[0]);
+  std::vector<ir::argument*> args = ir->get_function_list()[0]->args();
+  // signature
+  auto convert = [](ir::type *ty) {
+    if(ty->is_integer_ty(1))  return INT1_T;
+    if(ty->is_integer_ty(8))  return INT8_T;
+    if(ty->is_integer_ty(16)) return INT16_T;
+    if(ty->is_integer_ty(32)) return INT32_T;
+    if(ty->is_integer_ty(64)) return INT64_T;
+    if(ty->is_half_ty())      return HALF_T;
+    if(ty->is_float_ty())     return FLOAT_T;
+    if(ty->is_double_ty())    return DOUBLE_T;
+    if(ty->is_pointer_ty())   return BUFFER_T;
+    throw std::runtime_error("unknown type");
+  };
+  for(ir::argument* arg: args)
+    sig_.push_back(convert(arg->get_type()));
+  // find indices of autotune keys
+  for(const std::string& name: tune_key){
+    auto pred = [&](ir::argument* arg) { return arg->get_name() == name; };
+    auto it = std::find_if(args.begin(), args.end(), pred);
+    if(it == args.end())
+      throw std::runtime_error(name + " is not a valid argument name");
+    key_idxs_.push_back(std::distance(args.begin(), it));
   }
+  // find indices of pointer
+  for(size_t i = 0; i < args.size(); i++)
+    if(args[i]->get_type()->is_pointer_ty() ||
+       args[i]->get_type()->is_integer_ty())
+      align_idxs_.push_back(i);
   // argument size and offset
-  auto tys = kernels_.at(0).second->get_sig();
   size_t curr = 0;
-  for(arg_type ty: tys){
+  for(arg_type ty: sig_){
     arg_size_.push_back(size_of(ty));
     arg_off_.push_back(curr);
     curr += arg_size_.back();
   }
 }
-void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
-  runtime::kernel* fn = autotune(args, args_size, grid_fn, stream);
-  (*fn)(args, args_size, stream, grid_fn(fn->opt));
+uint64_t pow2_divisor(uint64_t N){
+  if(N % 16 == 0) return 16;
+  if(N % 8 == 0) return 8;
+  if(N % 4 == 0) return 4;
+  if(N % 2 == 0) return 2;
+  return 1;
 }
-void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) {
-  return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
+kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) {
+  // align key
+  std::vector<uint64_t> rt_key(align_idxs_.size(), 0);
+  for(size_t i = 0; i < align_idxs_.size(); i++){
+    int idx = align_idxs_[i];
+    uint64_t tmp = 0;
+    std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
+    rt_key[i] = pow2_divisor(tmp);
+  }
+  // auto-tuning key
+  std::vector<uint64_t> at_key(key_idxs_.size(), 0);
+  for(size_t i = 0; i < at_key.size(); i++){
+    int idx = key_idxs_[i];
+    std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
+  }
+  // cache key
+  std::vector<uint64_t> cache_key;
+  cache_key.reserve(rt_key.size() + at_key.size());
+  cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end());
+  cache_key.insert(cache_key.end(), at_key.begin(), at_key.end());
+  auto it = cache_.find(cache_key);
+  if(it != cache_.end())
+    return it->second;
+  // compile kernels
+  if(kernels_.find(rt_key) == kernels_.end()){
+    std::map<int, ir::attribute> attrs;
+    for(size_t i = 0; i < align_idxs_.size(); i++){
+      bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T;
+      attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])});
+    }
+    for(const options_t& opt: opts_)
+      kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs));
+  }
+  // run auto-tuner
+  double best_ts = INFINITY;
+  auto& kernels = kernels_.at(rt_key);
+  kernel* ret = nullptr;
+  if(kernels.size() == 1)
+    ret = &*kernels.back();
+  else{
+    for(auto &current : kernels_.at(rt_key)){
+      auto grid = grid_fn(current->opt);
+      while(grid.size() < 3)
+        grid.push_back(1);
+      double ts = tools::bench([&]() { (*current)(args, stream, grid); },
+                               stream, 5, 20);
+      ret = (ts < best_ts) ? &*current : ret;
+      best_ts = std::min(ts, best_ts);
+    }
+    stream->synchronize();
+  }
+  it = cache_.insert({cache_key, ret}).first;
+  return it->second;
+}
+void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
+  runtime::kernel* fn = autotune(args, grid_fn, stream);
+  (*fn)(args, stream, grid_fn(fn->opt));
 }