[general] major overhaul of triton-c/triton-ir/triton-jit:
- Added alloc const - Added atomics - Pruning tuning space - Added example for dot/conv/shift - Bugfixes
This commit is contained in:
38
lib/jit.cpp
38
lib/jit.cpp
@@ -68,7 +68,7 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
|
||||
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
auto &launch_info_map = launch_info_map_[result->getName()];
|
||||
@@ -79,14 +79,14 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &name, const std::string &src) {
|
||||
// create AST from Triton-C source
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
translation_unit *program = ast_root;
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("matrix", triton_context_);
|
||||
ir::module* module = new ir::module(name, triton_context_);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
@@ -97,18 +97,20 @@ jit::jit(driver::context *context): driver_context_(context),
|
||||
}
|
||||
|
||||
|
||||
void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(src);
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// std::cout << ranges.size() << std::endl;
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
double best = 0;
|
||||
@@ -117,51 +119,56 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
i = 0;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.init(tt_module);
|
||||
if(!passes.tune.check_constraints(errors))
|
||||
return;
|
||||
// Deep copy of the module and tuner
|
||||
auto ptt_module = make_triton_module(src);
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
|
||||
mp->set_value(params[i++]);
|
||||
}
|
||||
passes.tune.init(tt_module);
|
||||
passes.init(tt_module);
|
||||
passes.target_dependent(tt_module);
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
||||
return;
|
||||
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||
return;
|
||||
// Compile
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
|
||||
launch_information info = launch_info_map_.at("matmul");
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name.c_str()));
|
||||
launch_information info = launch_info_map_.at(name.c_str());
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
modules_.push_back(module.get());
|
||||
double perf;
|
||||
perf = benchmark(kernel.get(), info);
|
||||
best = std::max(perf, best);
|
||||
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||
modules_.pop_back();
|
||||
});
|
||||
}
|
||||
|
||||
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms) {
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
unsigned i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
|
||||
mp->set_value(params[i++]);
|
||||
passes.tune.init(tt_module);
|
||||
passes.init(tt_module);
|
||||
passes.target_dependent(tt_module);
|
||||
// check constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
passes.tune.check_constraints(errors);
|
||||
@@ -184,8 +191,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
}
|
||||
|
||||
void jit::add_module(const std::string &src, const std::vector<unsigned> ¶ms) {
|
||||
auto ptt_module = make_triton_module(src);
|
||||
void jit::add_module(const std::string &name, const std::string &src, const std::vector<unsigned> ¶ms) {
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
add_module(*ptt_module, params);
|
||||
}
|
||||
|
||||
@@ -201,4 +208,9 @@ unsigned jit::get_int(const std::string &name){
|
||||
return global_ints_.at(name);
|
||||
}
|
||||
|
||||
driver::buffer *jit::get_buffer(const std::string &name){
|
||||
driver::cu_module *mod = (driver::cu_module*)modules_.front();
|
||||
return mod->symbol(name.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user