[RUNTIME] Added option to print LLVM-IR
Also includes appropriate driver code change for that
This commit is contained in:
@@ -44,11 +44,8 @@ public:
|
|||||||
const std::string &features,
|
const std::string &features,
|
||||||
file_type_t file_type);
|
file_type_t file_type);
|
||||||
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
|
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
|
||||||
std::string llir() const { return llir_; }
|
|
||||||
int spilled() const { return spilled_; }
|
int spilled() const { return spilled_; }
|
||||||
|
|
||||||
private:
|
|
||||||
std::string llir_;
|
|
||||||
protected:
|
protected:
|
||||||
int spilled_;
|
int spilled_;
|
||||||
};
|
};
|
||||||
@@ -63,15 +60,18 @@ public:
|
|||||||
// CUDA
|
// CUDA
|
||||||
class cu_module: public module {
|
class cu_module: public module {
|
||||||
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
|
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
|
||||||
|
void init_from_ptx(const std::string& ptx);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
|
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
|
||||||
cu_module(driver::device* device, const std::string& source);
|
cu_module(driver::device* device, const std::string& source);
|
||||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||||
|
std::string llir() const { return llir_; }
|
||||||
const std::string& ptx() const { return ptx_; }
|
const std::string& ptx() const { return ptx_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string ptx_;
|
std::string ptx_;
|
||||||
|
std::string llir_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@@ -82,6 +82,7 @@ public:
|
|||||||
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
|
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
|
||||||
// getters
|
// getters
|
||||||
const std::vector<arg_type>& get_sig() const { return sig_; }
|
const std::vector<arg_type>& get_sig() const { return sig_; }
|
||||||
|
std::string get_asm(asm_mode_t mode);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void init_ir (const std::string &src);
|
void init_ir (const std::string &src);
|
||||||
|
@@ -99,44 +99,7 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
|||||||
llvm::SmallVectorImpl<char> &buffer,
|
llvm::SmallVectorImpl<char> &buffer,
|
||||||
const std::string& features,
|
const std::string& features,
|
||||||
file_type_t ft) {
|
file_type_t ft) {
|
||||||
init_llvm();
|
|
||||||
// // debug
|
|
||||||
llvm::legacy::PassManager pm;
|
|
||||||
std::string tmp;
|
|
||||||
// llvm::raw_string_ostream oss(llir_);
|
|
||||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
|
||||||
pm.add(llvm::createVerifierPass());
|
|
||||||
pm.run(*module);
|
|
||||||
// create machine
|
|
||||||
module->setTargetTriple(triple);
|
|
||||||
std::string error;
|
|
||||||
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
|
||||||
llvm::TargetOptions opt;
|
|
||||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
|
||||||
opt.UnsafeFPMath = false;
|
|
||||||
opt.NoInfsFPMath = false;
|
|
||||||
opt.NoNaNsFPMath = true;
|
|
||||||
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
|
||||||
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
|
||||||
// set data layout
|
|
||||||
if(layout.empty())
|
|
||||||
module->setDataLayout(machine->createDataLayout());
|
|
||||||
else
|
|
||||||
module->setDataLayout(layout);
|
|
||||||
// emit machine code
|
|
||||||
for (llvm::Function &f : module->functions())
|
|
||||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
|
||||||
llvm::legacy::PassManager pass;
|
|
||||||
llvm::raw_svector_ostream stream(buffer);
|
|
||||||
// convert triton file type to llvm file type
|
|
||||||
auto ll_file_type = [&](module::file_type_t type){
|
|
||||||
if(type == Object)
|
|
||||||
return llvm::CodeGenFileType::CGFT_ObjectFile;
|
|
||||||
return llvm::CodeGenFileType::CGFT_AssemblyFile;
|
|
||||||
};
|
|
||||||
// emit
|
|
||||||
machine->addPassesToEmitFile(pass, stream, nullptr, ll_file_type(ft));
|
|
||||||
pass.run(*module);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -271,7 +234,41 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
int ptx_minor = ptx % 10;
|
int ptx_minor = ptx % 10;
|
||||||
// create
|
// create
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
|
std::string triple = "nvptx64-nvidia-cuda";
|
||||||
|
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
||||||
|
std::string layout = "";
|
||||||
|
std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||||
|
init_llvm();
|
||||||
|
// verify and store llvm
|
||||||
|
llvm::legacy::PassManager pm;
|
||||||
|
pm.add(llvm::createVerifierPass());
|
||||||
|
pm.run(*module);
|
||||||
|
// create machine
|
||||||
|
module->setTargetTriple(triple);
|
||||||
|
std::string error;
|
||||||
|
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
||||||
|
llvm::TargetOptions opt;
|
||||||
|
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||||
|
opt.UnsafeFPMath = false;
|
||||||
|
opt.NoInfsFPMath = false;
|
||||||
|
opt.NoNaNsFPMath = true;
|
||||||
|
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
|
||||||
|
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||||
|
// set data layout
|
||||||
|
if(layout.empty())
|
||||||
|
module->setDataLayout(machine->createDataLayout());
|
||||||
|
else
|
||||||
|
module->setDataLayout(layout);
|
||||||
|
// emit machine code
|
||||||
|
for (llvm::Function &f : module->functions())
|
||||||
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||||
|
llvm::legacy::PassManager pass;
|
||||||
|
llvm::raw_svector_ostream stream(buffer);
|
||||||
|
// emit
|
||||||
|
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||||
|
pass.run(*module);
|
||||||
|
|
||||||
|
// post-process
|
||||||
std::string result(buffer.begin(), buffer.end());
|
std::string result(buffer.begin(), buffer.end());
|
||||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||||
@@ -280,10 +277,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void cu_module::init_from_ptx(const std::string& ptx) {
|
||||||
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }
|
|
||||||
|
|
||||||
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
|
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
|
|
||||||
try{
|
try{
|
||||||
@@ -295,7 +289,7 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul
|
|||||||
// std::string fsrc = _fsrc;
|
// std::string fsrc = _fsrc;
|
||||||
// std::string flog = _flog;
|
// std::string flog = _flog;
|
||||||
// std::ofstream ofs(fsrc);
|
// std::ofstream ofs(fsrc);
|
||||||
// ofs << source;
|
// ofs << ptx;
|
||||||
// ofs.close();
|
// ofs.close();
|
||||||
// std::string cmd;
|
// std::string cmd;
|
||||||
// int err;
|
// int err;
|
||||||
@@ -340,7 +334,7 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul
|
|||||||
}
|
}
|
||||||
catch(exception::cuda::invalid_ptx const &){
|
catch(exception::cuda::invalid_ptx const &){
|
||||||
//#ifdef TRITON_LOG_PTX_ERROR
|
//#ifdef TRITON_LOG_PTX_ERROR
|
||||||
std::cout << source << std::endl;
|
std::cout << ptx << std::endl;
|
||||||
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
||||||
// exit(1);
|
// exit(1);
|
||||||
//#endif
|
//#endif
|
||||||
@@ -348,6 +342,18 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): module(CUmodule(), true) {
|
||||||
|
llvm::raw_string_ostream oss(llir_);
|
||||||
|
oss << *ll_module;
|
||||||
|
oss.flush();
|
||||||
|
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||||
|
init_from_ptx(ptx_);
|
||||||
|
}
|
||||||
|
|
||||||
|
cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){
|
||||||
|
init_from_ptx(ptx_);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
|
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
|
||||||
CUdeviceptr handle;
|
CUdeviceptr handle;
|
||||||
size_t size;
|
size_t size;
|
||||||
|
@@ -224,6 +224,45 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
|
|||||||
stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
|
stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string kernel::get_asm(asm_mode_t mode) {
|
||||||
|
switch(mode){
|
||||||
|
case ASM_LLIR:{
|
||||||
|
return ((driver::cu_module*)mod_.get())->llir();
|
||||||
|
}
|
||||||
|
case ASM_NV_PTX:
|
||||||
|
case ASM_NV_SASS:{
|
||||||
|
std::string ptx = ((driver::cu_module*)mod_.get())->ptx();
|
||||||
|
// SASS
|
||||||
|
std::string input = std::tmpnam(nullptr);
|
||||||
|
std::string output = std::tmpnam(nullptr);
|
||||||
|
std::ofstream ofs(input);
|
||||||
|
ofs << ptx;
|
||||||
|
ofs.close();
|
||||||
|
if(mode == ASM_NV_PTX)
|
||||||
|
return ptx;
|
||||||
|
std::string cmd;
|
||||||
|
int err;
|
||||||
|
// compile ptx
|
||||||
|
driver::cu_device* cu_device = (driver::cu_device*)dev_;
|
||||||
|
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
|
||||||
|
err = system(cmd.c_str());
|
||||||
|
// disassemble
|
||||||
|
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
|
||||||
|
err = system(cmd.c_str());
|
||||||
|
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
|
||||||
|
std::string to_delete = " /*";
|
||||||
|
std::ifstream ifs(output);
|
||||||
|
std::string line;
|
||||||
|
std::string sass;
|
||||||
|
while(std::getline(ifs, line))
|
||||||
|
if(!std::regex_match(line, comment))
|
||||||
|
sass += line + "\n";
|
||||||
|
return sass;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
/* --------------------------------- */
|
/* --------------------------------- */
|
||||||
/* --------------------------------- */
|
/* --------------------------------- */
|
||||||
/* --------------------------------- */
|
/* --------------------------------- */
|
||||||
|
@@ -131,15 +131,14 @@ template<> struct to_string<double>{
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
float triton_dot(drv::context* context, drv::stream* stream,
|
||||||
int32_t M, int32_t N, int32_t K,
|
bool AT, bool BT,
|
||||||
const std::vector<int>& a_order, const std::vector<int>& b_order,
|
int32_t M, int32_t N, int32_t K){
|
||||||
std::vector<double>& bench, bool &test){
|
|
||||||
std::string ty = to_string<T>::value;
|
std::string ty = to_string<T>::value;
|
||||||
size_t dt_nbytes = sizeof(T);
|
size_t dt_nbytes = sizeof(T);
|
||||||
drv::device* device = context->device();
|
drv::device* device = context->device();
|
||||||
int32_t lda = (AT ^ a_order[0]==1) ? K : M;
|
int32_t lda = AT ? K : M;
|
||||||
int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
|
int32_t ldb = BT ? N : K;
|
||||||
int32_t ldc = N;
|
int32_t ldc = N;
|
||||||
std::vector<std::string> sa = { "1", "lda" };
|
std::vector<std::string> sa = { "1", "lda" };
|
||||||
std::vector<std::string> sb = { "1", "ldb" };
|
std::vector<std::string> sb = { "1", "ldb" };
|
||||||
@@ -156,18 +155,16 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
|||||||
ha[i] = (float)rand()/RAND_MAX;
|
ha[i] = (float)rand()/RAND_MAX;
|
||||||
for(size_t i = 0; i < hb.size(); i++)
|
for(size_t i = 0; i < hb.size(); i++)
|
||||||
hb[i] = (float)rand()/RAND_MAX;
|
hb[i] = (float)rand()/RAND_MAX;
|
||||||
// copy buffer
|
|
||||||
stream->write(&*da, true, 0, ha);
|
stream->write(&*da, true, 0, ha);
|
||||||
stream->write(&*db, true, 0, hb);
|
stream->write(&*db, true, 0, hb);
|
||||||
|
|
||||||
// macros
|
// macros
|
||||||
rt::options_space_t opts;
|
rt::options_space_t opts;
|
||||||
// A access patterns
|
// A access patterns
|
||||||
opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }});
|
||||||
opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }});
|
||||||
// B access patterns
|
// B access patterns
|
||||||
opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
|
opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }});
|
||||||
opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
|
opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }});
|
||||||
// data-type
|
// data-type
|
||||||
opts.defines.push_back({"TYPE", {ty}});
|
opts.defines.push_back({"TYPE", {ty}});
|
||||||
// tile sizes
|
// tile sizes
|
||||||
@@ -190,8 +187,9 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
|||||||
rt::add_arg(oss, ldb);
|
rt::add_arg(oss, ldb);
|
||||||
rt::add_arg(oss, ldc);
|
rt::add_arg(oss, ldc);
|
||||||
rt::add_arg(oss, *dlocks->cu());
|
rt::add_arg(oss, *dlocks->cu());
|
||||||
// kernel
|
// function
|
||||||
rt::function function(src::dot, opts, device);
|
rt::function function(src::dot, opts, device);
|
||||||
|
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
|
||||||
// grid
|
// grid
|
||||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||||
@@ -203,43 +201,37 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
|||||||
// metrics
|
// metrics
|
||||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream);
|
||||||
bench.push_back(tflops(triton_ns));
|
return tflops(triton_ns);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<double> bench_dot(drv::context* context, drv::stream* stream,
|
float bench_dot(drv::context* context, drv::stream* stream,
|
||||||
dtype_t dtype, bool AT, bool BT,
|
bool AT, bool BT,
|
||||||
int32_t M, int32_t N, int32_t K,
|
int32_t M, int32_t N, int32_t K,
|
||||||
const std::vector<int>& a_order, const std::vector<int>& b_order) {
|
dtype_t dtype) {
|
||||||
std::vector<double> bench;
|
|
||||||
bool test;
|
|
||||||
switch(dtype){
|
switch(dtype){
|
||||||
case HALF: triton_dot<half_float::half>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
case HALF: return triton_dot<half_float::half>(context, stream, AT, BT, M, N, K);
|
||||||
case FLOAT: triton_dot<float>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
case FLOAT: return triton_dot<float>(context, stream, AT, BT, M, N, K);
|
||||||
case DOUBLE: triton_dot<double>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
case DOUBLE: return triton_dot<double>(context, stream, AT, BT, M, N, K);
|
||||||
default: break;
|
default: return 0;
|
||||||
}
|
}
|
||||||
return bench;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
|
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
|
||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
|
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs = {
|
std::vector<config_t> configs = {
|
||||||
{{1, 0}, false, false, 8192, 8192, 8192}
|
{false, false, 8192, 8192, 8192}
|
||||||
};
|
};
|
||||||
// does the work
|
// does the work
|
||||||
std::vector<int> ord;
|
|
||||||
bool AT, BT;
|
bool AT, BT;
|
||||||
int32_t M, N, K;
|
int32_t M, N, K;
|
||||||
|
dtype_t dtype = HALF;
|
||||||
for(const auto& c: configs){
|
for(const auto& c: configs){
|
||||||
std::tie(ord, AT, BT, M, N, K) = c;
|
std::tie(AT, BT, M, N, K) = c;
|
||||||
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K ;
|
float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype);
|
||||||
for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord))
|
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl;
|
||||||
std::cout << ", " << perf << std::flush;
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user