[CODEGEN/DRIVER] Tweaks for performance optimization (#193)
This commit is contained in:
@@ -98,6 +98,7 @@ struct dispatch{
|
|||||||
|
|
||||||
// internal (debug/optimization)
|
// internal (debug/optimization)
|
||||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||||
|
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
|
||||||
static ir::value *debug_barrier(ir::builder *builder);
|
static ir::value *debug_barrier(ir::builder *builder);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -11,7 +11,8 @@ namespace ir{
|
|||||||
class metadata{
|
class metadata{
|
||||||
public:
|
public:
|
||||||
enum kind_t{
|
enum kind_t{
|
||||||
multiple_of
|
multiple_of,
|
||||||
|
max_contiguous
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
37
include/triton/tools/sys/exec.hpp
Normal file
37
include/triton/tools/sys/exec.hpp
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#ifndef TRITON_TOOLS_SYS_EXEC_HPP
|
||||||
|
#define TRITON_TOOLS_SYS_EXEC_HPP
|
||||||
|
|
||||||
|
#include <sys/wait.h>

#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
|
||||||
|
|
||||||
|
namespace triton
|
||||||
|
{
|
||||||
|
namespace tools
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// Runs `cmd` through popen() and captures what it writes to standard output.
///
/// @param cmd     Command line handed to popen(); opened in read ("r") mode.
/// @param result  Cleared first, then filled with the captured output. May hold
///                partial output when -1 is returned.
/// @return The command's exit status (0 on success) when it terminated
///         normally; -1 when the command could not be started, collecting its
///         output threw, or it did not exit normally (e.g. killed by a signal).
///         Callers compare the result against 0, so every failure path must
///         be non-zero.
///
/// Declared `inline` because the definition lives in a header.
inline int exec(const std::string& cmd, std::string& result) {
  // Never hand stale output back to the caller, even if popen() fails.
  result.clear();

  FILE* pipe = popen(cmd.c_str(), "r");
  if (pipe == nullptr)
    return -1;

  char buffer[128];
  try {
    while (fgets(buffer, sizeof buffer, pipe) != nullptr)
      result += buffer;
  } catch (...) {
    // Appending to the string may throw; close the pipe and report failure
    // instead of a misleading "success" code.
    pclose(pipe);
    return -1;
  }

  const int status = pclose(pipe);
  // WEXITSTATUS is only meaningful for a child that exited normally;
  // otherwise it could yield 0 for a command that was killed.
  if (status == -1 || !WIFEXITED(status))
    return -1;
  return WEXITSTATUS(status);
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -331,6 +331,11 @@ std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
|
|||||||
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||||
return max_contiguous_.at(v);
|
return max_contiguous_.at(v);
|
||||||
|
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||||
|
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
||||||
|
if(max_contiguous > 0)
|
||||||
|
return add_to_cache(x, {max_contiguous}, max_contiguous_);
|
||||||
|
}
|
||||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||||
return populate_max_contiguous_cast(x);
|
return populate_max_contiguous_cast(x);
|
||||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||||
|
@@ -29,6 +29,7 @@
|
|||||||
#include "triton/tools/sha1.hpp"
|
#include "triton/tools/sha1.hpp"
|
||||||
#include "triton/tools/sys/getenv.hpp"
|
#include "triton/tools/sys/getenv.hpp"
|
||||||
#include "triton/tools/sys/mkdir.hpp"
|
#include "triton/tools/sys/mkdir.hpp"
|
||||||
|
#include "triton/tools/sys/exec.hpp"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Verifier.h"
|
#include "llvm/IR/Verifier.h"
|
||||||
#include "llvm/IR/IRPrintingPasses.h"
|
#include "llvm/IR/IRPrintingPasses.h"
|
||||||
@@ -299,10 +300,13 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
|||||||
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
|
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
try{
|
try{
|
||||||
std::string ptxas = tools::getenv("TRITON_PTXAS");
|
// use ptxas if present in PATH. Otherwise, use JIT from the driver
|
||||||
|
std::string ptxas = "ptxas";
|
||||||
|
std::string version;
|
||||||
|
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
|
||||||
|
|
||||||
// Use PTXAS via system call
|
// Use PTXAS via system call
|
||||||
if(!ptxas.empty()){
|
if(use_system_ptxas){
|
||||||
// compile ptx with ptxas
|
// compile ptx with ptxas
|
||||||
char _fsrc[] = "/tmp/triton_k_XXXXXX";
|
char _fsrc[] = "/tmp/triton_k_XXXXXX";
|
||||||
char _flog[] = "/tmp/triton_l_XXXXXX";
|
char _flog[] = "/tmp/triton_l_XXXXXX";
|
||||||
@@ -316,7 +320,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device)
|
|||||||
std::string cmd;
|
std::string cmd;
|
||||||
int err;
|
int err;
|
||||||
std::string cc = std::to_string(device->compute_capability());
|
std::string cc = std::to_string(device->compute_capability());
|
||||||
cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
cmd = ptxas + " -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||||
err = system(cmd.c_str());
|
err = system(cmd.c_str());
|
||||||
dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
|
dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
|
||||||
unlink(_fsrc);
|
unlink(_fsrc);
|
||||||
|
@@ -711,6 +711,14 @@ ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
|
|||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attaches `max_contiguous` metadata with the given `value` to `x` and returns
// `x` viewed as an instruction. `x` must be an ir::instruction; any other kind
// of value is reported through throw_unreachable. The builder argument is unused.
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
  auto *inst = dynamic_cast<ir::instruction*>(x);
  if(inst == nullptr)
    throw_unreachable("max_contiguous");
  inst->set_metadata(ir::metadata::max_contiguous, value);
  return inst;
}
|
||||||
|
|
||||||
ir::value *dispatch::debug_barrier(ir::builder *builder) {
|
ir::value *dispatch::debug_barrier(ir::builder *builder) {
|
||||||
return builder->create_barrier();
|
return builder->create_barrier();
|
||||||
}
|
}
|
||||||
|
@@ -161,6 +161,7 @@ void init_triton_frontend(py::module &&m) {
|
|||||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||||
// internal (debugging only)
|
// internal (debugging only)
|
||||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||||
|
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
|
||||||
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
|
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -637,6 +637,14 @@ def multiple_of(input, value, builder=None):
|
|||||||
return frontend.multiple_of(input, value, builder)
|
return frontend.multiple_of(input, value, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def max_contiguous(input, value, builder=None):
|
||||||
|
"""
|
||||||
|
Let the compiler know that the first `value` values in :code:`input` are contiguous.
|
||||||
|
"""
|
||||||
|
return frontend.max_contiguous(input, value, builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Standard library
|
# Standard library
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
Reference in New Issue
Block a user