From 0b025db2ee5d02027d2cfb1d985d03fadbd6548f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 31 Jan 2021 01:01:32 -0500 Subject: [PATCH] [RUNTIME] Added option to print LLVM-IR Also includes appropriate driver code change for that --- include/triton/driver/module.h | 6 +- include/triton/runtime/function.h | 1 + lib/driver/module.cc | 96 ++++++++++++++++--------------- lib/runtime/function.cc | 39 +++++++++++++ tutorials/01-matmul.cc | 60 +++++++++---------- 5 files changed, 120 insertions(+), 82 deletions(-) diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index 0cdfbb84c..df98d5eb2 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -44,11 +44,8 @@ public: const std::string &features, file_type_t file_type); virtual std::unique_ptr symbol(const char * name) const = 0; - std::string llir() const { return llir_; } int spilled() const { return spilled_; } -private: - std::string llir_; protected: int spilled_; }; @@ -63,15 +60,18 @@ public: // CUDA class cu_module: public module { std::string compile_llvm_module(std::unique_ptr module, driver::device* device); + void init_from_ptx(const std::string& ptx); public: cu_module(driver::device* device, std::unique_ptr module); cu_module(driver::device* device, const std::string& source); std::unique_ptr symbol(const char * name) const; + std::string llir() const { return llir_; } const std::string& ptx() const { return ptx_; } private: std::string ptx_; + std::string llir_; }; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 319bc6fdb..f3fc41298 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -82,6 +82,7 @@ public: void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector& grid) const; // getters const std::vector& get_sig() const { return sig_; } + std::string get_asm(asm_mode_t mode); private: void init_ir (const std::string &src); diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 1ccfd4091..f0035fc84 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -99,44 +99,7 @@ void module::compile_llvm_module(std::unique_ptr module, const std llvm::SmallVectorImpl &buffer, const std::string& features, file_type_t ft) { - init_llvm(); -// // debug - llvm::legacy::PassManager pm; - std::string tmp; -// llvm::raw_string_ostream oss(llir_); -// pm.add(llvm::createPrintModulePass(llvm::outs())); - pm.add(llvm::createVerifierPass()); - pm.run(*module); - // create machine - module->setTargetTriple(triple); - std::string error; - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - llvm::TargetOptions opt; - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, - llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); - // set data layout - if(layout.empty()) - module->setDataLayout(machine->createDataLayout()); - else - module->setDataLayout(layout); - // emit machine code - for (llvm::Function &f : module->functions()) - f.addFnAttr(llvm::Attribute::AlwaysInline); - llvm::legacy::PassManager pass; - llvm::raw_svector_ostream stream(buffer); - // convert triton file type to llvm file type - auto ll_file_type = [&](module::file_type_t type){ - if(type == Object) - return llvm::CodeGenFileType::CGFT_ObjectFile; - return llvm::CodeGenFileType::CGFT_AssemblyFile; - }; - // emit - machine->addPassesToEmitFile(pass, stream, nullptr, ll_file_type(ft)); - pass.run(*module); + } @@ -271,7 +234,41 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, int ptx_minor = ptx % 10; // create llvm::SmallVector buffer; - module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly); + std::string triple = "nvptx64-nvidia-cuda"; + std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc)); + std::string layout = ""; + std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); + init_llvm(); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createVerifierPass()); + pm.run(*module); + // create machine + module->setTargetTriple(triple); + std::string error; + auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetOptions opt; + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, + llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); + // set data layout + if(layout.empty()) + module->setDataLayout(machine->createDataLayout()); + else + module->setDataLayout(layout); + // emit machine code + for (llvm::Function &f : module->functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + llvm::legacy::PassManager pass; + llvm::raw_svector_ostream stream(buffer); + // emit + machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); + pass.run(*module); + + // post-process std::string result(buffer.begin(), buffer.end()); find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); @@ -280,10 +277,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, return result; } - -cu_module::cu_module(driver::device* device, std::unique_ptr ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { } - -cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){ +void cu_module::init_from_ptx(const std::string& ptx) { // JIT compile source-code try{ @@ -295,7 +289,7 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul // std::string fsrc = _fsrc; // std::string flog = _flog; // std::ofstream ofs(fsrc); -// ofs << source; +// ofs << ptx; // ofs.close(); // std::string cmd; // int err; @@ -340,7 +334,7 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul } catch(exception::cuda::invalid_ptx const &){ //#ifdef TRITON_LOG_PTX_ERROR - std::cout << source << std::endl; + std::cout << ptx << std::endl; std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; // exit(1); //#endif @@ -348,6 +342,18 @@ cu_module::cu_module(driver::device* device, std::string const & source) : modul } } +cu_module::cu_module(driver::device* device, std::unique_ptr ll_module): module(CUmodule(), true) { + llvm::raw_string_ostream oss(llir_); + oss << *ll_module; + oss.flush(); + ptx_ = compile_llvm_module(std::move(ll_module), device); + init_from_ptx(ptx_); +} + +cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){ + init_from_ptx(ptx_); +} + std::unique_ptr cu_module::symbol(const char *name) const{ CUdeviceptr handle; size_t size; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index f23c720c9..3f8c7cafc 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -224,6 +224,45 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size); } +std::string kernel::get_asm(asm_mode_t mode) { + switch(mode){ + case ASM_LLIR:{ + return ((driver::cu_module*)mod_.get())->llir(); + } + case ASM_NV_PTX: + case ASM_NV_SASS:{ + std::string ptx = ((driver::cu_module*)mod_.get())->ptx(); + // SASS + std::string input = std::tmpnam(nullptr); + std::string output = std::tmpnam(nullptr); + std::ofstream ofs(input); + ofs << ptx; + ofs.close(); + if(mode == ASM_NV_PTX) + return ptx; + std::string cmd; + int err; + // compile ptx + driver::cu_device* cu_device = (driver::cu_device*)dev_; + cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o"; + err = system(cmd.c_str()); + // disassemble + cmd = "cuobjdump --dump-sass " + input + ".o >> " + output; + err = system(cmd.c_str()); + std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/"); + std::string to_delete = " /*"; + std::ifstream ifs(output); + std::string line; + std::string sass; + while(std::getline(ifs, line)) + if(!std::regex_match(line, comment)) + sass += line + "\n"; + return sass; + } + default: + return ""; + } +} /* --------------------------------- */ /* --------------------------------- */ /* --------------------------------- */ diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc index 6606ad22f..8f4e67643 100644 --- a/tutorials/01-matmul.cc +++ b/tutorials/01-matmul.cc @@ -131,15 +131,14 @@ template<> struct to_string{ }; template -void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - const std::vector& a_order, const std::vector& b_order, - std::vector& bench, bool &test){ +float triton_dot(drv::context* context, drv::stream* stream, + bool AT, bool BT, + int32_t M, int32_t N, int32_t K){ std::string ty = to_string::value; size_t dt_nbytes = sizeof(T); drv::device* device = context->device(); - int32_t lda = (AT ^ a_order[0]==1) ? K : M; - int32_t ldb = (BT ^ b_order[0]==1) ? N : K; + int32_t lda = AT ? K : M; + int32_t ldb = BT ? N : K; int32_t ldc = N; std::vector sa = { "1", "lda" }; std::vector sb = { "1", "ldb" }; @@ -156,18 +155,16 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, ha[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hb.size(); i++) hb[i] = (float)rand()/RAND_MAX; - // copy buffer stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); - // macros rt::options_space_t opts; // A access patterns - opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); - opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); + opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }}); + opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }}); // B access patterns - opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }}); - opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }}); + opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }}); + opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }}); // data-type opts.defines.push_back({"TYPE", {ty}}); // tile sizes @@ -190,8 +187,9 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, rt::add_arg(oss, ldb); rt::add_arg(oss, ldc); rt::add_arg(oss, *dlocks->cu()); - // kernel + // function rt::function function(src::dot, opts, device); +// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl; // grid auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; auto grid = [ceil, M, N](const rt::options_t& x) { @@ -203,43 +201,37 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, // metrics auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream); - bench.push_back(tflops(triton_ns)); + return tflops(triton_ns); } -std::vector bench_dot(drv::context* context, drv::stream* stream, - dtype_t dtype, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - const std::vector& a_order, const std::vector& b_order) { - std::vector bench; - bool test; +float bench_dot(drv::context* context, drv::stream* stream, + bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + dtype_t dtype) { switch(dtype){ - case HALF: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; - case FLOAT: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; - case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; - default: break; + case HALF: return triton_dot(context, stream, AT, BT, M, N, K); + case FLOAT: return triton_dot(context, stream, AT, BT, M, N, K); + case DOUBLE: return triton_dot(context, stream, AT, BT, M, N, K); + default: return 0; } - return bench; } - int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::driver::stream* stream = triton::driver::stream::create(context->backend()); // shapes to benchmark - typedef std::tuple, bool, bool, int, int, int> config_t; + typedef std::tuple config_t; std::vector configs = { - {{1, 0}, false, false, 8192, 8192, 8192} + {false, false, 8192, 8192, 8192} }; // does the work - std::vector ord; bool AT, BT; int32_t M, N, K; + dtype_t dtype = HALF; for(const auto& c: configs){ - std::tie(ord, AT, BT, M, N, K) = c; - std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K ; - for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord)) - std::cout << ", " << perf << std::flush; - std::cout << std::endl; + std::tie(AT, BT, M, N, K) = c; + float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype); + std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl; } }