[general] major overhaul of triton-c/triton-ir/triton-jit:

- Added alloc const
- Added atomics
- Pruning tuning space
- Added example for dot/conv/shift
- Bugfixes
This commit is contained in:
Philippe Tillet
2019-04-25 16:17:36 -04:00
parent 0c607c9392
commit 3413aad582
50 changed files with 2051 additions and 570 deletions

View File

@@ -68,7 +68,7 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
passes.selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
@@ -79,14 +79,14 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
return std::unique_ptr<llvm::Module>(result);
}
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &name, const std::string &src) {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
// create Triton-IR from AST
ir::module* module = new ir::module("matrix", triton_context_);
ir::module* module = new ir::module(name, triton_context_);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
@@ -97,18 +97,20 @@ jit::jit(driver::context *context): driver_context_(context),
}
void jit::autotune(const std::string &src, benchmark_t benchmark) {
void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) {
// find metaparameters
auto ptt_module = make_triton_module(src);
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// std::cout << ranges.size() << std::endl;
// iterate over parameters
unsigned i;
double best = 0;
@@ -117,51 +119,56 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
i = 0;
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
passes.target_independent(tt_module);
passes.tune.init(tt_module);
if(!passes.tune.check_constraints(errors))
return;
// Deep copy of the module and tuner
auto ptt_module = make_triton_module(src);
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
mp->set_value(params[i++]);
}
passes.tune.init(tt_module);
passes.init(tt_module);
passes.target_dependent(tt_module);
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
launch_information info = launch_info_map_.at("matmul");
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name.c_str()));
launch_information info = launch_info_map_.at(name.c_str());
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
modules_.push_back(module.get());
double perf;
perf = benchmark(kernel.get(), info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
modules_.pop_back();
});
}
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
unsigned i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
mp->set_value(params[i++]);
passes.tune.init(tt_module);
passes.init(tt_module);
passes.target_dependent(tt_module);
// check constraints
std::map<ir::value*, std::vector<std::string>> errors;
passes.tune.check_constraints(errors);
@@ -184,8 +191,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
}
void jit::add_module(const std::string &src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(src);
void jit::add_module(const std::string &name, const std::string &src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(name, src);
add_module(*ptt_module, params);
}
@@ -201,4 +208,9 @@ unsigned jit::get_int(const std::string &name){
return global_ints_.at(name);
}
driver::buffer *jit::get_buffer(const std::string &name){
driver::cu_module *mod = (driver::cu_module*)modules_.front();
return mod->symbol(name.c_str());
}
}