[GENERAL] Improved caching mechanism:
* Now computing hashes in libtriton
* Now compiling only a single PyTorch hook per function signature
committed by Philippe Tillet
parent 30f77e9ec5
commit dfb844bf41
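
The practical effect of this commit: the frontend can ask libtriton for a kernel's argument signature and reuse one compiled PyTorch hook for every kernel that shares it. A minimal, self-contained C++ sketch of that per-signature cache (the enum mirrors rt::arg_type from the diff below; compile_hook, hook_cache, and the sample signature are hypothetical stand-ins, not Triton code):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// subset of rt::arg_type, as introduced by this commit
enum arg_type { INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
                HALF_T, FLOAT_T, DOUBLE_T, BUFFER_T };

// stand-in for make_torch_src(name, args): returns C++ source for a hook
std::string compile_hook(const std::vector<arg_type>& sig) {
  return "/* torch op specialized for " + std::to_string(sig.size()) + " args */";
}

int main() {
  // the cache is keyed by signature, not by kernel source:
  // two kernels with the same signature share one compiled hook
  std::map<std::vector<arg_type>, std::string> hook_cache;
  std::vector<arg_type> sig = {BUFFER_T, BUFFER_T, INT32_T};
  if(hook_cache.find(sig) == hook_cache.end())
    hook_cache[sig] = compile_hook(sig);           // compiled once
  std::cout << hook_cache.size() << " hook(s) compiled\n";  // prints 1
}
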
@@ -6,6 +6,7 @@
 #include <regex>
 #include <algorithm>
 #include "triton/runtime/function.h"
+#include "triton/runtime/arg.h"
 #include "triton/lang/code_gen.h"
 #include "triton/lang/parser.h"
 #include "triton/lang/cpp.h"
@@ -40,9 +41,10 @@ void delete_grid(size_t id) {
 /* Function map */
 
 void register_fn(size_t id,
-                 const std::string& src,
-                 const rt::function::options_space_t& opt) {
-  id_fn_map[id].reset(new rt::function(src, opt));
+                 const std::string& src,
+                 const rt::function::options_space_t& opt,
+                 const std::string &cache_ref) {
+  id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
 }
 
 void delete_fn(size_t id) {
@@ -64,6 +66,7 @@ size_t make_op_id() {
   return id_fn_map.size();
 }
 
+
 /* TF scalar wrapper */
 size_t make_scalar_id() {
   size_t ret = i64scalar_map.size();
@@ -423,6 +426,37 @@ inline std::string to_torch_ty(ir::type *ty) {
   throw std::runtime_error("unknown type");
 }
 
+inline std::string to_torch_ty(rt::arg_type ty){
+  switch(ty){
+    case rt::INT1_T: return "int64_t";
+    case rt::INT8_T: return "int64_t";
+    case rt::INT16_T: return "int64_t";
+    case rt::INT32_T: return "int64_t";
+    case rt::INT64_T: return "int64_t";
+    case rt::HALF_T: return "double";
+    case rt::FLOAT_T: return "double";
+    case rt::DOUBLE_T: return "double";
+    case rt::BUFFER_T: return "torch::Tensor";
+    default: return "UNKNOWN";
+  }
+}
+
+inline std::string to_c_ty(rt::arg_type ty){
+  switch(ty){
+    case rt::INT1_T: return "bool";
+    case rt::INT8_T: return "int8_t";
+    case rt::INT16_T: return "int16_t";
+    case rt::INT32_T: return "int32_t";
+    case rt::INT64_T: return "int64_t";
+    case rt::HALF_T: return "half";
+    case rt::FLOAT_T: return "float";
+    case rt::DOUBLE_T: return "double";
+    case rt::BUFFER_T: return "drv::cu_buffer";
+    default: return "UNKNOWN";
+  }
+}
+
+
 inline std::string to_c_ty(ir::type *ty) {
   if(ty->is_integer_ty(1))
     return "bool";
@@ -448,33 +482,30 @@ inline std::string to_c_ty(ir::type *ty) {
 
 
 void gen_torch_signature(std::ostringstream& oss,
-                         ir::function* fn,
-                         const std::string& name) {
-  const auto& args = fn->args();
+                         const std::string& name,
+                         const std::vector<rt::arg_type>& args) {
   std::string ret_ty = "void";
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
   oss << "int64_t bench_id, ";
   for(size_t i = 0; i < args.size(); i++) {
-    ir::argument* arg = args[i];
     if(i > 0)
       oss << ", ";
-    oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
+    oss << to_torch_ty(args[i]) << " " << "th_arg_" << i;
   }
   oss << ")";
 }
 
 void gen_torch_init_driver(std::ostringstream &oss,
-                           const std::vector<ir::argument*>&args) {
-  ir::argument* tensor = nullptr;
-  for(ir::argument* arg: args)
-    if(arg->get_type()->is_pointer_ty()){
-      tensor = arg;
+                           const std::vector<rt::arg_type>&args) {
+  // Find index of first buffer
+  size_t i;
+  for(i = 0; i < args.size(); i++)
+    if(args[i] == rt::BUFFER_T)
       break;
-    }
   oss << "  // Wrap CUDA handles" << std::endl;
-  oss << "  c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl;
+  oss << "  c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl;
   oss << "  // Get stream" << std::endl;
   oss << "  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
   oss << "  triton::driver::cu_stream stream(custream, false);" << std::endl;
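
With signatures expressed as rt::arg_type lists, gen_torch_signature no longer reads IR arguments; it names wrapper parameters positionally (th_arg_0, th_arg_1, ...). A standalone sketch of the same loop, with a made-up kernel name `add` and a sample signature (both illustrative):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

enum arg_type { INT32_T, BUFFER_T };  // subset of rt::arg_type

// same mapping as to_torch_ty(rt::arg_type) above, for this subset
std::string to_torch_ty(arg_type ty) {
  return ty == BUFFER_T ? "torch::Tensor" : "int64_t";
}

int main() {
  std::vector<arg_type> args = {BUFFER_T, BUFFER_T, INT32_T};
  std::ostringstream oss;
  // fixed prefix emitted before the per-argument loop
  oss << "void add(int64_t id, int64_t bench, int64_t bench_id, ";
  for(size_t i = 0; i < args.size(); i++) {
    if(i > 0) oss << ", ";
    oss << to_torch_ty(args[i]) << " th_arg_" << i;
  }
  oss << ")";
  std::cout << oss.str() << std::endl;
  // prints: void add(int64_t id, int64_t bench, int64_t bench_id,
  //         torch::Tensor th_arg_0, torch::Tensor th_arg_1, int64_t th_arg_2)
}
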
@@ -482,28 +513,28 @@ void gen_torch_init_driver(std::ostringstream &oss,
 }
 
 void gen_torch_make_handles(std::ostream &os,
-                            const std::vector<ir::argument*>& args) {
+                            const std::vector<rt::arg_type>& args) {
   for(unsigned i = 0; i < args.size(); i++){
-    ir::argument *arg = args[i];
-    const std::string& name = arg->get_name();
-    ir::type* ty = arg->get_type();
-    if(!ty->is_pointer_ty())
-      os << "  " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
+    rt::arg_type arg = args[i];
+    const std::string th_name = "th_arg_" + std::to_string(i);
+    const std::string name = "arg_" + std::to_string(i);
+    if(arg != rt::BUFFER_T)
+      os << "  " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl;
     else{
-      os << "  CHECK_INPUT(" << name << ");" << std::endl;
-      os << "  drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
-            " (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
+      os << "  CHECK_INPUT(" << th_name << ");" << std::endl;
+      os << "  drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), "
+            " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl;
     }
   }
 }
 
-void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
+void gen_torch_make_launch_function(std::ostream &os,
+                                    const std::vector<rt::arg_type>& args) {
   os << "  std::function<void()> run = [&](){\n  ";
   os << "  (*id_fn_map.at(id))({";
   for(unsigned i = 0; i < args.size() ; i++){
-    ir::argument *arg = args[i];
-    std::string name = "arg_" + arg->get_name();
-    if(arg->get_type()->is_pointer_ty())
+    std::string name = "arg_" + std::to_string(i);
+    if(args[i] == rt::BUFFER_T)
      name = "&" + name;
     if(i > 0)
       os << ", ";
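
gen_torch_make_handles now derives both the Torch-side name (th_arg_i) and the kernel-side name (arg_i) from the argument index alone, so the generated wrapper never depends on IR argument names. A standalone sketch of that emission loop under the same naming scheme (the sample signature and main are illustrative, not part of the commit):

#include <iostream>
#include <string>
#include <vector>

enum arg_type { INT32_T, BUFFER_T };  // subset of rt::arg_type

// same mapping as to_c_ty(rt::arg_type) above, for this subset
std::string to_c_ty(arg_type ty) {
  return ty == BUFFER_T ? "drv::cu_buffer" : "int32_t";
}

// emits one handle declaration per argument, exactly as in the diff:
// scalars are copied, buffers are wrapped in a drv::cu_buffer
void make_handles(std::ostream& os, const std::vector<arg_type>& args) {
  for(unsigned i = 0; i < args.size(); i++){
    const std::string th_name = "th_arg_" + std::to_string(i);
    const std::string name = "arg_" + std::to_string(i);
    if(args[i] != BUFFER_T)
      os << "  " << to_c_ty(args[i]) << " " << name << " = " << th_name << ";\n";
    else {
      os << "  CHECK_INPUT(" << th_name << ");\n";
      os << "  drv::cu_buffer " << name << "(ctx, " << th_name << ".storage().size(), "
            "(CUdeviceptr)((char*)" << th_name << ".storage().data() + "
         << th_name << ".storage_offset() * " << th_name << ".itemsize()), false);\n";
    }
  }
}

int main() {
  make_handles(std::cout, {BUFFER_T, INT32_T});
  // prints the CHECK_INPUT/cu_buffer lines for arg_0, then:
  //   int32_t arg_1 = th_arg_1;
}
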
@@ -531,15 +562,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
 }
 
 std::tuple<std::string,
-           std::string> make_torch_src(const std::string& src,
-                                       const runtime::function::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  std::string name = fn->get_name();
+           std::string> make_torch_src(const std::string& name, std::vector<rt::arg_type> args) {
   // generate framework code
   std::ostringstream oss;
   oss << R"(
@@ -563,11 +586,11 @@ extern std::map<size_t, int64_t> i64scalar_map;
 
 )";
 
-  gen_torch_signature(oss, fn, name);
+  gen_torch_signature(oss, name, args);
   oss << " {" << std::endl;
-  gen_torch_init_driver(oss, fn->args());
-  gen_torch_make_handles(oss, fn->args());
-  gen_torch_make_launch_function(oss, fn->args());
+  gen_torch_init_driver(oss, args);
+  gen_torch_make_handles(oss, args);
+  gen_torch_make_launch_function(oss, args);
   //gen_torch_ret(oss);
   oss << "}" << std::endl;
 
@@ -578,6 +601,22 @@ extern std::map<size_t, int64_t> i64scalar_map;
   return {oss.str(), name};
 }
 
+/* Function signature */
+std::vector<rt::arg_type> get_fn_signature(const std::string& src,
+                                           const runtime::function::options_space_t& opt) {
+  // triton-ir code-gen
+  ir::context ctx;
+  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
+  make_module(src, &*ir, opt);
+  // function
+  ir::function* fn = ir->get_function_list().front();
+  // extract signature
+  std::vector<rt::arg_type> ret;
+  ir::function_type* ty = fn->get_fn_type();
+  for(size_t i = 0; i < ty->get_num_params(); i++)
+    ret.push_back(rt::convert(ty->get_param_ty(i)));
+  return ret;
+}
+
 typedef triton::runtime::function::options_t options_t;
 typedef triton::runtime::function::options_space_t options_space_t;
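
get_fn_signature is what lets the caller key its hook cache: it runs only Triton-IR code-gen, then maps each kernel parameter type to an rt::arg_type tag via rt::convert. A toy sketch of that extraction step with stand-in IR types (the real rt::convert lives in libtriton; every name here is illustrative):

#include <iostream>
#include <vector>

// stand-ins for triton IR parameter kinds and rt::arg_type
enum ir_kind { IR_INT32, IR_FLOAT, IR_POINTER };
enum arg_type { INT32_T, FLOAT_T, BUFFER_T };

// hypothetical analogue of rt::convert(ir::type*)
arg_type convert(ir_kind k) {
  switch(k) {
    case IR_INT32:   return INT32_T;
    case IR_FLOAT:   return FLOAT_T;
    case IR_POINTER: return BUFFER_T;  // pointer parameters become buffers
  }
  return INT32_T;
}

// same loop shape as get_fn_signature: one tag per kernel parameter
std::vector<arg_type> get_signature(const std::vector<ir_kind>& params) {
  std::vector<arg_type> ret;
  for(ir_kind k : params)
    ret.push_back(convert(k));
  return ret;
}

int main() {
  std::vector<arg_type> sig = get_signature({IR_POINTER, IR_POINTER, IR_INT32});
  std::cout << sig.size() << " args in signature\n";  // prints 3
}
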
@@ -593,6 +632,17 @@ PYBIND11_MODULE(libtriton, m) {
         "Creates C++ source code for a custom PyTorch op ");
 
   // bindings for triton classes
+  pybind11::enum_<rt::arg_type>(m, "arg_type")
+      .value("int1", rt::INT1_T)
+      .value("int8", rt::INT8_T)
+      .value("int16", rt::INT16_T)
+      .value("int32", rt::INT32_T)
+      .value("int64", rt::INT64_T)
+      .value("half", rt::HALF_T)
+      .value("float", rt::FLOAT_T)
+      .value("double", rt::DOUBLE_T)
+      .value("buffer", rt::BUFFER_T);
+
   pybind11::class_<options_t>(m, "options")
       .def(pybind11::init<>())
      .def("d", &options_t::D<int>)
@@ -604,6 +654,7 @@ PYBIND11_MODULE(libtriton, m) {
       .def_readwrite("num_warps", &options_space_t::num_warps);
 
   // hooks into triton constructs since frameworks may not use pybind11
+  m.def("get_fn_signature", &get_fn_signature);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);