From 298da78058bb85d2cb7d1a6ec09be24018a9b557 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 7 Aug 2021 16:41:44 -0700 Subject: [PATCH] [CODEGEN/DRIVER] Tweaks for performance optimization (#193) --- include/triton/ir/dispatch.h | 1 + include/triton/ir/metadata.h | 3 ++- include/triton/tools/sys/exec.hpp | 37 +++++++++++++++++++++++++++++++ lib/codegen/analysis/align.cc | 5 +++++ lib/driver/module.cc | 10 ++++++--- lib/ir/dispatch.cc | 8 +++++++ python/src/triton.cc | 1 + python/triton/language.py | 8 +++++++ 8 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 include/triton/tools/sys/exec.hpp diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index 1afb58fd5..c034dc191 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -98,6 +98,7 @@ struct dispatch{ // internal (debug/optimization) static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); + static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); static ir::value *debug_barrier(ir::builder *builder); }; diff --git a/include/triton/ir/metadata.h b/include/triton/ir/metadata.h index e595da36c..9d4fb1137 100644 --- a/include/triton/ir/metadata.h +++ b/include/triton/ir/metadata.h @@ -11,7 +11,8 @@ namespace ir{ class metadata{ public: enum kind_t{ - multiple_of + multiple_of, + max_contiguous }; private: diff --git a/include/triton/tools/sys/exec.hpp b/include/triton/tools/sys/exec.hpp new file mode 100644 index 000000000..63e27609c --- /dev/null +++ b/include/triton/tools/sys/exec.hpp @@ -0,0 +1,37 @@ +#ifndef TRITON_TOOLS_SYS_EXEC_HPP +#define TRITON_TOOLS_SYS_EXEC_HPP + +#include +#include +#include +#include +#include + +namespace triton +{ +namespace tools +{ + + + +int exec(const std::string& cmd, std::string& result) { + char buffer[128]; + FILE* pipe = popen(cmd.c_str(), "r"); + if (!pipe) + return 0; + result.clear(); + try { + while (fgets(buffer, sizeof buffer, pipe) != NULL) + result 
+= buffer; + } catch (...) { + pclose(pipe); + return 0; + } + return WEXITSTATUS(pclose(pipe)); + +} + +} +} + +#endif diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index f12718a8d..b972f6185 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -331,6 +331,11 @@ std::vector align::populate_max_contiguous_cast(ir::cast_inst* v){ std::vector align::populate_max_contiguous(ir::value *v){ if(max_contiguous_.find(v) != max_contiguous_.end()) return max_contiguous_.at(v); + if(auto *x = dynamic_cast(v)){ + unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); + if(max_contiguous > 0) + return add_to_cache(x, {max_contiguous}, max_contiguous_); + } if(auto *x = dynamic_cast(v)) return populate_max_contiguous_cast(x); if(auto *x = dynamic_cast(v)) diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 83ef7e832..497e2f029 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -29,6 +29,7 @@ #include "triton/tools/sha1.hpp" #include "triton/tools/sys/getenv.hpp" #include "triton/tools/sys/mkdir.hpp" +#include "triton/tools/sys/exec.hpp" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Verifier.h" #include "llvm/IR/IRPrintingPasses.h" @@ -299,10 +300,13 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) { // JIT compile source-code try{ - std::string ptxas = tools::getenv("TRITON_PTXAS"); + // use ptxas if present in PATH. 
Otherwise, use JIT from the driver + std::string ptxas = "ptxas"; + std::string version; + int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; // Use PTXAS via system call - if(!ptxas.empty()){ + if(use_system_ptxas){ // compile ptx with ptxas char _fsrc[] = "/tmp/triton_k_XXXXXX"; char _flog[] = "/tmp/triton_l_XXXXXX"; @@ -316,7 +320,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) std::string cmd; int err; std::string cc = std::to_string(device->compute_capability()); - cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; + cmd = ptxas + " -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); unlink(_fsrc); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 63fcb9723..a55e318b5 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -711,6 +711,14 @@ ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ return i; } +ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ + ir::instruction* i = dynamic_cast(x); + if(!i) + throw_unreachable("max_contiguous"); + i->set_metadata(ir::metadata::max_contiguous, value); + return i; +} + ir::value *dispatch::debug_barrier(ir::builder *builder) { return builder->create_barrier(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 47db99998..63214fa9c 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -161,6 +161,7 @@ void init_triton_frontend(py::module &&m) { m.def("sqrt", &ir::dispatch::sqrt, ret::reference); // internal (debugging only) m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); + m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference); m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); } diff --git a/python/triton/language.py b/python/triton/language.py index cea27d9ff..bc14b4235 100644 --- 
a/python/triton/language.py +++ b/python/triton/language.py @@ -637,6 +637,14 @@ def multiple_of(input, value, builder=None): return frontend.multiple_of(input, value, builder) +@builtin +def max_contiguous(input, value, builder=None): + """ + Let the compiler know that the first `value` values in :code:`input` are contiguous. + """ + return frontend.max_contiguous(input, value, builder) + + # ----------------------- # Standard library # -----------------------