[CODEGEN/DRIVER] Tweaks for performance optimization (#193)
This commit is contained in:
		| @@ -98,6 +98,7 @@ struct dispatch{ | |||||||
|  |  | ||||||
|   // internal (debug/optimization) |   // internal (debug/optimization) | ||||||
|   static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); |   static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); | ||||||
|  |   static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); | ||||||
|   static ir::value *debug_barrier(ir::builder *builder); |   static ir::value *debug_barrier(ir::builder *builder); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,7 +11,8 @@ namespace ir{ | |||||||
| class metadata{ | class metadata{ | ||||||
| public: | public: | ||||||
|   enum kind_t{ |   enum kind_t{ | ||||||
|     multiple_of |     multiple_of, | ||||||
|  |     max_contiguous | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
| private: | private: | ||||||
|   | |||||||
							
								
								
									
										37
									
								
								include/triton/tools/sys/exec.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								include/triton/tools/sys/exec.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | |||||||
#ifndef TRITON_TOOLS_SYS_EXEC_HPP
#define TRITON_TOOLS_SYS_EXEC_HPP

#include <sys/wait.h>

#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

namespace triton
{
namespace tools
{

/**
 * Runs `cmd` through popen() in read mode and captures what it writes to
 * its standard output.
 *
 * @param cmd     command line handed to popen().
 * @param result  cleared once the pipe is open, then filled with the
 *                captured output; left untouched if the pipe cannot be opened.
 * @return the command's exit status when it terminated normally, or -1 when
 *         the pipe could not be opened, an exception interrupted reading,
 *         pclose() failed, or the child did not exit normally.
 *
 * Failures are reported as -1 rather than 0 so that callers comparing the
 * return value against 0 cannot mistake a failed launch for success.
 *
 * Declared `inline` because the definition lives in a header that may be
 * included from several translation units.
 */
inline int exec(const std::string& cmd, std::string& result) {
  FILE* pipe = popen(cmd.c_str(), "r");
  if (!pipe)
    return -1;
  result.clear();
  char buffer[128];
  try {
    while (fgets(buffer, sizeof buffer, pipe) != nullptr)
      result += buffer;
  } catch (...) {
    // Appending to `result` may throw; close the pipe before reporting failure.
    pclose(pipe);
    return -1;
  }
  const int status = pclose(pipe);
  // WEXITSTATUS is only meaningful when the child exited normally.
  if (status == -1 || !WIFEXITED(status))
    return -1;
  return WEXITSTATUS(status);
}

}
}

#endif
| @@ -331,6 +331,11 @@ std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){ | |||||||
| std::vector<unsigned> align::populate_max_contiguous(ir::value *v){ | std::vector<unsigned> align::populate_max_contiguous(ir::value *v){ | ||||||
|   if(max_contiguous_.find(v) != max_contiguous_.end()) |   if(max_contiguous_.find(v) != max_contiguous_.end()) | ||||||
|     return max_contiguous_.at(v); |     return max_contiguous_.at(v); | ||||||
|  |   if(auto *x = dynamic_cast<ir::instruction*>(v)){ | ||||||
|  |     unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); | ||||||
|  |     if(max_contiguous > 0) | ||||||
|  |       return add_to_cache(x, {max_contiguous}, max_contiguous_); | ||||||
|  |   } | ||||||
|   if(auto *x = dynamic_cast<ir::cast_inst*>(v)) |   if(auto *x = dynamic_cast<ir::cast_inst*>(v)) | ||||||
|     return populate_max_contiguous_cast(x); |     return populate_max_contiguous_cast(x); | ||||||
|   if(auto *x = dynamic_cast<ir::splat_inst*>(v)) |   if(auto *x = dynamic_cast<ir::splat_inst*>(v)) | ||||||
|   | |||||||
| @@ -29,6 +29,7 @@ | |||||||
| #include "triton/tools/sha1.hpp" | #include "triton/tools/sha1.hpp" | ||||||
| #include "triton/tools/sys/getenv.hpp" | #include "triton/tools/sys/getenv.hpp" | ||||||
| #include "triton/tools/sys/mkdir.hpp" | #include "triton/tools/sys/mkdir.hpp" | ||||||
|  | #include "triton/tools/sys/exec.hpp" | ||||||
| #include "llvm/IR/IRBuilder.h" | #include "llvm/IR/IRBuilder.h" | ||||||
| #include "llvm/IR/Verifier.h" | #include "llvm/IR/Verifier.h" | ||||||
| #include "llvm/IR/IRPrintingPasses.h" | #include "llvm/IR/IRPrintingPasses.h" | ||||||
| @@ -299,10 +300,13 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* | |||||||
| void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) { | void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) { | ||||||
|   // JIT compile source-code |   // JIT compile source-code | ||||||
|   try{ |   try{ | ||||||
|     std::string ptxas = tools::getenv("TRITON_PTXAS"); |     // use ptxas if present in PATH. Otherwise, use JIT from the driver | ||||||
|  |     std::string ptxas = "ptxas"; | ||||||
|  |     std::string version; | ||||||
|  |     int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; | ||||||
|  |  | ||||||
|     // Use PTXAS via system call |     // Use PTXAS via system call | ||||||
|     if(!ptxas.empty()){ |     if(use_system_ptxas){ | ||||||
|       // compile ptx with ptxas |       // compile ptx with ptxas | ||||||
|       char _fsrc[] = "/tmp/triton_k_XXXXXX"; |       char _fsrc[] = "/tmp/triton_k_XXXXXX"; | ||||||
|       char _flog[] = "/tmp/triton_l_XXXXXX"; |       char _flog[] = "/tmp/triton_l_XXXXXX"; | ||||||
| @@ -316,7 +320,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) | |||||||
|       std::string cmd; |       std::string cmd; | ||||||
|       int err; |       int err; | ||||||
|       std::string cc = std::to_string(device->compute_capability()); |       std::string cc = std::to_string(device->compute_capability()); | ||||||
|       cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; |       cmd = ptxas + " -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; | ||||||
|       err = system(cmd.c_str()); |       err = system(cmd.c_str()); | ||||||
|       dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); |       dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); | ||||||
|       unlink(_fsrc); |       unlink(_fsrc); | ||||||
|   | |||||||
| @@ -711,6 +711,14 @@ ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ | |||||||
|   return i; |   return i; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ | ||||||
|  |   ir::instruction* i = dynamic_cast<ir::instruction*>(x); | ||||||
|  |   if(!i) | ||||||
|  |     throw_unreachable("max_contiguous"); | ||||||
|  |   i->set_metadata(ir::metadata::max_contiguous, value); | ||||||
|  |   return i; | ||||||
|  | } | ||||||
|  |  | ||||||
| ir::value *dispatch::debug_barrier(ir::builder *builder) { | ir::value *dispatch::debug_barrier(ir::builder *builder) { | ||||||
|   return builder->create_barrier(); |   return builder->create_barrier(); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -161,6 +161,7 @@ void init_triton_frontend(py::module &&m) { | |||||||
|   m.def("sqrt", &ir::dispatch::sqrt, ret::reference); |   m.def("sqrt", &ir::dispatch::sqrt, ret::reference); | ||||||
|   // internal (debugging only) |   // internal (debugging only) | ||||||
|   m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); |   m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); | ||||||
|  |   m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference); | ||||||
|   m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); |   m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -637,6 +637,14 @@ def multiple_of(input, value, builder=None): | |||||||
|     return frontend.multiple_of(input, value, builder) |     return frontend.multiple_of(input, value, builder) | ||||||
|  |  | ||||||
|  |  | ||||||
@builtin
def max_contiguous(input, value, builder=None):
    """
    Let the compiler know that the first `value` values in :code:`input` are contiguous.
    """
    hinted = frontend.max_contiguous(input, value, builder)
    return hinted
|  |  | ||||||
|  |  | ||||||
| # ----------------------- | # ----------------------- | ||||||
| # Standard library | # Standard library | ||||||
| # ----------------------- | # ----------------------- | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user