[GENERAL] Merged v1.0alpha into master. Added features are:
- A100 support via mma.16816 - Thread swizzling for conflict-free shared memory accesses without padding - Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering - Added debugging capabilities in the Python binding - Compilation error for kernels that spill
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
#include <string>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <functional>
|
||||
@@ -9,11 +9,13 @@
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
@@ -29,6 +31,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/runtime/error.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/tools/sha1.hpp"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
@@ -67,7 +70,7 @@ void _loop_nest(std::vector<size_t> const & ranges,
|
||||
/* OPTIONS */
|
||||
/* --------------------- */
|
||||
|
||||
std::string function::options_t::to_str() const{
|
||||
std::string options_t::to_str() const{
|
||||
std::string ret = "nw-" + std::to_string(num_warps);
|
||||
for(const auto& x : defines){
|
||||
ret += '-';
|
||||
@@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) {
|
||||
throw std::runtime_error("unknown type");
|
||||
}
|
||||
|
||||
void function::caller::write(std::ofstream &ofs) {
|
||||
// write name
|
||||
ofs << name_ << std::endl;
|
||||
// write signature
|
||||
for(size_t i = 0; i < param_tys_.size(); i++)
|
||||
ofs << param_tys_[i] << " ";
|
||||
ofs << std::endl;
|
||||
// write module
|
||||
std::string source = ((driver::cu_module*)(&*parent_))->source();
|
||||
ofs << source;
|
||||
}
|
||||
//void function::caller::write(std::ofstream &ofs) {
|
||||
// // write name
|
||||
// ofs << name_ << std::endl;
|
||||
// // write signature
|
||||
// for(size_t i = 0; i < param_tys_.size(); i++)
|
||||
// ofs << param_tys_[i] << " ";
|
||||
// ofs << std::endl;
|
||||
// // write module
|
||||
// std::string source = ((driver::cu_module*)(&*parent_))->ptx();
|
||||
// ofs << source;
|
||||
//}
|
||||
|
||||
void function::caller::read(std::ifstream &ifs) {
|
||||
// read name
|
||||
std::getline(ifs, name_);
|
||||
// read signature
|
||||
std::string line;
|
||||
std::getline(ifs, line);
|
||||
std::istringstream current(line);
|
||||
int param;
|
||||
param_tys_.clear();
|
||||
while(current >> param)
|
||||
param_tys_.push_back((arg_type)param);
|
||||
// read module
|
||||
std::string src((std::istreambuf_iterator<char>(ifs)),
|
||||
std::istreambuf_iterator<char>());
|
||||
parent_.reset(new driver::cu_module(src));
|
||||
bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
|
||||
//void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
|
||||
// // read name
|
||||
// std::getline(ifs, name_);
|
||||
// // read signature
|
||||
// std::string line;
|
||||
// std::getline(ifs, line);
|
||||
// std::istringstream current(line);
|
||||
// int param;
|
||||
// param_tys_.clear();
|
||||
// while(current >> param)
|
||||
// param_tys_.push_back((arg_type)param);
|
||||
// // read module
|
||||
// std::string src((std::istreambuf_iterator<char>(ifs)),
|
||||
// std::istreambuf_iterator<char>());
|
||||
// parent_.reset(new driver::cu_module(ctx, src));
|
||||
// bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
|
||||
|
||||
}
|
||||
//}
|
||||
|
||||
function::caller::caller(std::ifstream &ifs, const options_t& opt)
|
||||
: opt_(opt) {
|
||||
read(ifs);
|
||||
}
|
||||
//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
|
||||
// : opt_(opt) {
|
||||
// read(ctx, ifs);
|
||||
//}
|
||||
|
||||
function::caller::caller(ir::function *ir,
|
||||
std::shared_ptr<driver::module> parent, const options_t& opt)
|
||||
@@ -198,20 +201,23 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
// optimizations
|
||||
bool cts_use_async = target->as_nvidia()->sm() >= 80;
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::transform::cts cts(cts_use_async);
|
||||
codegen::transform::disassociate disassociate;
|
||||
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::swizzle swizzle(&layouts, target.get());
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::peephole peephole(target.get());
|
||||
codegen::transform::reassociate reassociate;
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::cts cts;
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
dce.run(module);
|
||||
disassociate.run(module);
|
||||
@@ -233,17 +239,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
|
||||
}
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
swizzle.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
if(allocation.allocated_size() > device->max_shared_memory())
|
||||
throw std::runtime_error("using too much shared memory");
|
||||
throw exception::out_of_shared_memory();
|
||||
barriers.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
isel.visit(module, *llvm);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
|
||||
if(res->spilled() > 256)
|
||||
throw exception::out_of_registers();
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) {
|
||||
auto ir = make_ir(parser);
|
||||
// triton-ir -> binary
|
||||
std::unique_ptr<driver::module> bin;
|
||||
// try{
|
||||
try{
|
||||
bin = make_bin(*ir, device, opt);
|
||||
// }catch(const std::runtime_error&){
|
||||
// return nullptr;
|
||||
// }
|
||||
}catch(const exception::base&){
|
||||
throw;
|
||||
}
|
||||
// create callable
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
|
||||
@@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space)
|
||||
for(const auto& x: space.defines)
|
||||
ranges.push_back(x.second.size());
|
||||
// functor for source with given option
|
||||
std::map<options_t, std::string> err;
|
||||
auto do_make = [&](std::vector<size_t> params) {
|
||||
// compilation options
|
||||
unsigned i = 0;
|
||||
@@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space)
|
||||
for(auto D: space.defines)
|
||||
opt.defines[D.first] = D.second[params[i++]];
|
||||
// compile
|
||||
make(device, opt);
|
||||
try{
|
||||
make(device, opt);
|
||||
}catch(const exception::base& e){
|
||||
err[opt] = e.what();
|
||||
}
|
||||
};
|
||||
// multi-threaded compilation
|
||||
_loop_nest(ranges, do_make);
|
||||
if(callers_.empty())
|
||||
throw std::runtime_error("could not compile kernel");
|
||||
if(callers_.empty()){
|
||||
std::ostringstream dbg;
|
||||
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
|
||||
for(auto x: err){
|
||||
dbg << "[ ";
|
||||
dbg << x.first.num_warps << ", ";
|
||||
dbg << "{ ";
|
||||
for(const auto& y: x.first.defines)
|
||||
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
|
||||
dbg << " } ] -> " << x.second << std::endl;
|
||||
}
|
||||
throw exception::no_valid_configuration(dbg.str());
|
||||
}
|
||||
}
|
||||
|
||||
std::string function::ptx(driver::device* device, const options_t& opt) {
|
||||
std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
|
||||
make(device, opt);
|
||||
const auto& fn = callers_.at(opt);
|
||||
if(!fn)
|
||||
return "";
|
||||
return ((driver::cu_module*)fn->parent())->source();
|
||||
switch(mode){
|
||||
case ASM_LLIR:{
|
||||
return fn->parent()->llir();
|
||||
}
|
||||
case ASM_NV_PTX:
|
||||
case ASM_NV_SASS:{
|
||||
std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
|
||||
// SASS
|
||||
std::string input = std::tmpnam(nullptr);
|
||||
std::string output = std::tmpnam(nullptr);
|
||||
std::ofstream ofs(input);
|
||||
ofs << ptx;
|
||||
ofs.close();
|
||||
if(mode == ASM_NV_PTX)
|
||||
return ptx;
|
||||
std::string cmd;
|
||||
int err;
|
||||
// compile ptx
|
||||
driver::cu_device* cu_device = (driver::cu_device*)device;
|
||||
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
|
||||
err = system(cmd.c_str());
|
||||
// disassemble
|
||||
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
|
||||
err = system(cmd.c_str());
|
||||
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
|
||||
std::string to_delete = " /*";
|
||||
std::ifstream ifs(output);
|
||||
std::string line;
|
||||
std::string sass;
|
||||
while(std::getline(ifs, line))
|
||||
if(!std::regex_match(line, comment))
|
||||
sass += line + "\n";
|
||||
return sass;
|
||||
}
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
// returns program with best compilation options for given parameter
|
||||
|
Reference in New Issue
Block a user