[GENERAL] Improved caching mechanism:
* Now computing hash in libtriton * Now only compiling a single pytorch hook per function signature
This commit is contained in:
committed by
Philippe Tillet
parent
30f77e9ec5
commit
dfb844bf41
@@ -253,11 +253,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// std::cout << source << std::endl;
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
@@ -266,10 +266,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
|
||||
try{
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
||||
}catch(exception::cuda::base const &){
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
#ifdef TRITON_LOG_PTX_ERROR
|
||||
std::cerr << "Compilation Failed! Log: " << std::endl;
|
||||
std::cerr << errbuf << std::endl;
|
||||
//#endif
|
||||
#endif
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
@@ -28,24 +28,28 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/tools/sha1.hpp"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/tools/sys/mkdir.hpp"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include <mutex>
|
||||
#include <fstream>
|
||||
|
||||
std::mutex mut;
|
||||
|
||||
namespace triton{
|
||||
namespace runtime {
|
||||
|
||||
// helpers
|
||||
void _parallel_loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f,
|
||||
size_t nthreads){
|
||||
/* --------------------- */
|
||||
/* HELPERS */
|
||||
/* --------------------- */
|
||||
|
||||
void _loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f){
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
while(true){
|
||||
// Execute function
|
||||
f(values);
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
@@ -56,24 +60,31 @@ void _parallel_loop_nest(std::vector<size_t> const & ranges,
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void _parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
|
||||
//Ranges to iterate over
|
||||
std::vector<size_t> ranges;
|
||||
for(auto const & x: iterates)
|
||||
ranges.push_back(x.size());
|
||||
//Proxy function
|
||||
auto proxy = [&](std::vector<size_t> const & idx){
|
||||
std::vector<T> x(iterates.size());
|
||||
for(size_t i = 0; i < x.size(); ++i)
|
||||
x[i] = iterates[i][idx[i]];
|
||||
f(x);
|
||||
};
|
||||
//Iterate
|
||||
_parallel_loop_nest(ranges, proxy, nthreads);
|
||||
|
||||
/* --------------------- */
|
||||
/* OPTIONS */
|
||||
/* --------------------- */
|
||||
|
||||
std::string function::options_t::to_str() const{
|
||||
std::string ret = "nw-" + std::to_string(num_warps);
|
||||
for(const auto& x : defines){
|
||||
ret += '-';
|
||||
ret += x.first;
|
||||
ret += '-';
|
||||
ret += x.second;
|
||||
}
|
||||
// legalize
|
||||
for(char& x: ret){
|
||||
if(x == ' ' || x == '^' || x == ',' || x == ':')
|
||||
x = '_';
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// caller
|
||||
|
||||
/* --------------------- */
|
||||
/* CALLER OBJECT */
|
||||
/* --------------------- */
|
||||
|
||||
arg_type convert(ir::type *ty) {
|
||||
if(ty->is_integer_ty(1))
|
||||
@@ -97,8 +108,46 @@ arg_type convert(ir::type *ty) {
|
||||
throw std::runtime_error("unknown type");
|
||||
}
|
||||
|
||||
function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> parent, const options_t& opt)
|
||||
: bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), parent_(parent), opt_(opt) {
|
||||
void function::caller::write(std::ofstream &ofs) {
|
||||
// write name
|
||||
ofs << name_ << std::endl;
|
||||
// write signature
|
||||
for(size_t i = 0; i < param_tys_.size(); i++)
|
||||
ofs << param_tys_[i] << " ";
|
||||
ofs << std::endl;
|
||||
// write module
|
||||
std::string source = ((driver::cu_module*)(&*parent_))->source();
|
||||
ofs << source;
|
||||
}
|
||||
|
||||
void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
|
||||
// read name
|
||||
std::getline(ifs, name_);
|
||||
// read signature
|
||||
std::string line;
|
||||
std::getline(ifs, line);
|
||||
std::istringstream current(line);
|
||||
int param;
|
||||
param_tys_.clear();
|
||||
while(current >> param)
|
||||
param_tys_.push_back((arg_type)param);
|
||||
// read module
|
||||
std::string src((std::istreambuf_iterator<char>(ifs)),
|
||||
std::istreambuf_iterator<char>());
|
||||
parent_.reset(new driver::cu_module(ctx, src));
|
||||
bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
|
||||
|
||||
}
|
||||
|
||||
function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
|
||||
: opt_(opt) {
|
||||
read(ctx, ifs);
|
||||
}
|
||||
|
||||
function::caller::caller(ir::function *ir,
|
||||
std::shared_ptr<driver::module> parent, const options_t& opt)
|
||||
: parent_(parent), opt_(opt), name_(ir->get_name()) {
|
||||
bin_.reset(driver::kernel::create(&*parent, name_.c_str()));
|
||||
// extract signature
|
||||
ir::function_type* ty = ir->get_fn_type();
|
||||
for(size_t i = 0; i < ty->get_num_params(); i++)
|
||||
@@ -109,6 +158,7 @@ function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> paren
|
||||
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector<arg>& args) const {
|
||||
if(args.size() != param_tys_.size())
|
||||
throw std::runtime_error("invalid number of arguments");
|
||||
// set arguments
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
arg arg_i = args.at(i);
|
||||
arg_type ty = arg_i.type();
|
||||
@@ -119,99 +169,33 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
||||
else
|
||||
bin_->setArg(i, size_of(ty), arg_i.data());
|
||||
}
|
||||
// sanity check
|
||||
// set grid
|
||||
if(_grid.size() > 3)
|
||||
throw std::runtime_error("grid size must be no greater than 3");
|
||||
std::array<size_t, 3> grid;
|
||||
for(size_t i = 0; i < 3; i++)
|
||||
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
|
||||
// enqueue
|
||||
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1});
|
||||
}
|
||||
|
||||
|
||||
/* --------------------- */
|
||||
/* FUNCTION */
|
||||
/* --------------------- */
|
||||
|
||||
// create Triton-IR from AST
|
||||
std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("", ctx_);
|
||||
Generator gen(&parser);
|
||||
gen.Gen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
|
||||
function::caller function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
|
||||
const std::vector<arg>& args) {
|
||||
|
||||
// all tuning parameters are strings
|
||||
std::vector<std::string> num_warps;
|
||||
for(size_t i: opt_space_.num_warps)
|
||||
num_warps.push_back(std::to_string(i));
|
||||
std::vector<std::vector<std::string>> space;
|
||||
space.push_back(num_warps);
|
||||
for(const auto& i: opt_space_.defines)
|
||||
space.push_back(i.second);
|
||||
|
||||
// exhaustive search
|
||||
double best_ts = INFINITY;
|
||||
std::unique_ptr<caller> ret;
|
||||
|
||||
auto benchmark = [&](std::vector<std::string> params) {
|
||||
// extract options
|
||||
options_t opt;
|
||||
unsigned i = 0;
|
||||
opt.num_warps = std::stoi(params[i++]);
|
||||
for(auto it: opt_space_.defines){
|
||||
opt.defines[it.first] = params[i++];
|
||||
}
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src_, true);
|
||||
for(auto it: opt_space_.defines)
|
||||
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
||||
cpp.Process(tokens);
|
||||
|
||||
// parse
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
// triton-ir code-gen
|
||||
auto ir = make_ir(parser);
|
||||
// binary code-gen
|
||||
std::unique_ptr<driver::module> bin;
|
||||
try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(const std::runtime_error& e){
|
||||
return;
|
||||
}
|
||||
// kernel uses too much resources
|
||||
if(!bin)
|
||||
return;
|
||||
// copy constants
|
||||
std::unique_ptr<driver::buffer> buffer;
|
||||
for(ir::alloc_const* alloc: ir->allocs()){
|
||||
std::string name = alloc->get_name();
|
||||
auto it = cst_.find(name);
|
||||
if(it == cst_.end())
|
||||
throw std::runtime_error("constant not set before execution");
|
||||
buffer = bin->symbol(name.c_str());
|
||||
stream->write(&*buffer, true, 0, it->second);
|
||||
}
|
||||
// benchmark
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
caller call(tmp, std::move(bin), opt);
|
||||
double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream, true);
|
||||
// save best
|
||||
if(ts < best_ts) {
|
||||
best_ts = ts;
|
||||
ret.reset(new caller(call));
|
||||
}
|
||||
};
|
||||
_parallel_loop_nest<std::string>(space, benchmark, 1);
|
||||
if(!ret)
|
||||
throw std::runtime_error("could not find valid option in provided space");
|
||||
return *ret;
|
||||
}
|
||||
|
||||
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
|
||||
// create Binary from Triton-IR
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
driver::context *context,
|
||||
const options_t& opt) {
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
@@ -236,8 +220,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
align.run(module);
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
@@ -258,16 +240,135 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
isel.visit(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
// done
|
||||
// exit(EXIT_FAILURE);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// create Binary from options
|
||||
function::caller* function::make(driver::stream *stream, options_t opt) {
|
||||
// cache path
|
||||
std::string cache_path = cache_path_ + opt.to_str() + ".ptx";
|
||||
int ref_mtime = tools::mtime(cache_ref_);
|
||||
int ptx_mtime = tools::mtime(cache_path);
|
||||
// if cached ptx is newer than reference library
|
||||
if(!ref_mtime || !ptx_mtime || ref_mtime < ptx_mtime){
|
||||
std::ifstream ifs(cache_path);
|
||||
// file is empty -- invalid
|
||||
if(ifs && ifs.peek() == std::ifstream::traits_type::eof())
|
||||
return nullptr;
|
||||
// load cached caller
|
||||
if(ifs)
|
||||
return new caller(stream->context(), ifs, opt);
|
||||
}
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src_, true);
|
||||
for(auto it: opt.defines)
|
||||
cpp.AddMacro(it.first, &it.second);
|
||||
cpp.Process(tokens);
|
||||
// src -> ast
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
// ast -> triton-ir
|
||||
auto ir = make_ir(parser);
|
||||
// triton-ir -> binary
|
||||
std::unique_ptr<driver::module> bin;
|
||||
try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(const std::runtime_error&){
|
||||
if(!cache_path_.empty())
|
||||
std::ofstream ofs(cache_path);
|
||||
return nullptr;
|
||||
}
|
||||
// create callable
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
caller* ret = new caller(tmp, std::move(bin), opt);
|
||||
// serialize callable
|
||||
if(!cache_path_.empty()){
|
||||
std::ofstream ofs(cache_path);
|
||||
ret->write(ofs);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// precompile all kernels spanned by given options space
|
||||
void function::precompile(driver::stream* stream,
|
||||
const options_space_t& space) {
|
||||
// all ranges
|
||||
std::vector<size_t> ranges;
|
||||
ranges.push_back(space.num_warps.size());
|
||||
for(const auto& x: space.defines)
|
||||
ranges.push_back(x.second.size());
|
||||
// functor for source with given option
|
||||
auto do_make = [&](std::vector<size_t> params) {
|
||||
// compilation options
|
||||
unsigned i = 0;
|
||||
options_t opt;
|
||||
opt.num_warps = space.num_warps[params[i++]];
|
||||
for(auto D: space.defines)
|
||||
opt.defines[D.first] = D.second[params[i++]];
|
||||
// compile
|
||||
caller* call = make(stream, opt);
|
||||
if(!call)
|
||||
return;
|
||||
// copy constants
|
||||
std::unique_ptr<driver::buffer> buffer;
|
||||
for(const auto& cst: cst_){
|
||||
buffer = call->parent()->symbol(cst.first.c_str());
|
||||
stream->write(&*buffer, true, 0, cst.second);
|
||||
}
|
||||
callers_[opt].reset(call);
|
||||
};
|
||||
// multi-threaded compilation
|
||||
_loop_nest(ranges, do_make);
|
||||
if(callers_.empty())
|
||||
throw std::runtime_error("could not find valid option in provided space");
|
||||
}
|
||||
|
||||
// return auto-tuning key for given function arguments
|
||||
function::cache_key_t function::get_key(driver::stream *stream, const std::vector<arg>& args) {
|
||||
cache_key_t ret;
|
||||
ret.first = stream->context()->device();
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
arg_type ty = args.at(i).type();
|
||||
if(!is_int_type(ty))
|
||||
continue;
|
||||
long val = 0;
|
||||
std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
|
||||
ret.second.push_back(val);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
// returns program with best compilation options for given parameter
|
||||
function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
|
||||
const std::vector<arg>& args) {
|
||||
// fast path -- no autotuning necessary
|
||||
if(callers_.size() == 1)
|
||||
return &*callers_.begin()->second;
|
||||
// slow path -- autotuning necessary
|
||||
double best_ts = INFINITY;
|
||||
caller* ret = nullptr;
|
||||
for(auto &x : callers_){
|
||||
if(x.second == nullptr)
|
||||
throw std::runtime_error("configuration not compiled");
|
||||
caller* current = &*x.second;
|
||||
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args); },
|
||||
stream, true);
|
||||
ret = (ts < best_ts) ? current : ret;
|
||||
best_ts = std::min(ts, best_ts);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// set copy host buffer "data" into constant memory buffer "name"
|
||||
void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
|
||||
cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
|
||||
}
|
||||
|
||||
|
||||
std::string function::preheader() {
|
||||
return
|
||||
R"(
|
||||
return R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
@@ -297,47 +398,65 @@ typedef long int64;
|
||||
)";
|
||||
}
|
||||
|
||||
function::function(const std::string &src, const options_space_t& opt): src_(src), opt_space_(opt) {
|
||||
std::string function::get_cache_prefix() {
|
||||
//user-specified cache path
|
||||
std::string result = tools::getenv("TRITON_CACHE_PATH");
|
||||
if(!result.empty()){
|
||||
if(tools::mkpath(result)==0)
|
||||
return result;
|
||||
}
|
||||
//create in home
|
||||
result = tools::getenv("HOME");
|
||||
if(!result.empty())
|
||||
{
|
||||
result = result + "/.triton/cache/";
|
||||
if(tools::mkpath(result)==0)
|
||||
return result;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
function::function(const std::string &src,
|
||||
const options_space_t& opt,
|
||||
const std::string &cache_ref):
|
||||
src_(src), opt_(opt), cache_ref_(cache_ref) {
|
||||
// hash source code
|
||||
unsigned char hash[20];
|
||||
sha1::calc((void*)src_.data(), src_.size(), hash);
|
||||
// create cache path
|
||||
char _hex[40];
|
||||
sha1::toHexString(hash, _hex);
|
||||
std::string hex(_hex, _hex + 40);
|
||||
cache_path_ = get_cache_prefix() + hex + "/";
|
||||
tools::mkpath(cache_path_);
|
||||
// append pre-header to source
|
||||
src_ = preheader() + src_;
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
cache_key_t key;
|
||||
|
||||
/* figure out if the kernel should be re-tuned */
|
||||
// re-tune if device is different
|
||||
key.first = stream->context()->device();
|
||||
// re-tune if any int argument is different
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
arg_type ty = args.at(i).type();
|
||||
if(is_int_type(ty)){
|
||||
long val = 0;
|
||||
std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
|
||||
key.second.push_back(val);
|
||||
}
|
||||
}
|
||||
|
||||
/* find existing configuration */
|
||||
void function::operator()(const std::vector<arg>& args,
|
||||
const grid_fn_ty& grid_fn,
|
||||
driver::stream *stream) {
|
||||
// pre-compile kernels
|
||||
if(callers_.empty())
|
||||
precompile(stream, opt_);
|
||||
// auto-tune if necessary
|
||||
auto key = get_key(stream, args);
|
||||
auto it = cache_.find(key);
|
||||
if(it != cache_.end()){
|
||||
it->second(stream, grid_fn(it->second.opt()), args);
|
||||
return;
|
||||
}
|
||||
|
||||
/* re-tune and re-compile */
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mut);
|
||||
cache_.insert({key, autotune(stream, grid_fn, args)});
|
||||
if(it == cache_.end()){
|
||||
auto best = autotune(stream, grid_fn, args);
|
||||
it = cache_.insert({key, best}).first;
|
||||
}
|
||||
// run
|
||||
(*it->second)(stream, grid_fn(it->second->opt()), args);
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream *stream) {
|
||||
void function::operator()(const std::vector<arg>& args,
|
||||
const grid_t& grid,
|
||||
driver::stream *stream) {
|
||||
return this->operator()(args, [&grid](const options_t&){ return grid; }, stream);
|
||||
}
|
||||
|
||||
void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
|
||||
cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user