[GENERAL] Improved caching mechanism:

* Now computing the cache hash in libtriton
* Now compiling only a single PyTorch hook per function signature
Authored and committed by Philippe Tillet on 2020-02-20 20:09:33 -05:00
commit dfb844bf41 (parent 30f77e9ec5)
14 changed files with 538 additions and 435 deletions

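The change below reworks make_torch_src so that the generated PyTorch hook depends only on an op name and its argument-type signature, rather than on the kernel's full Triton-C source; every kernel that shares a signature can therefore reuse one compiled hook. A standalone sketch of that caching idea (illustrative only: the types and the compile step are stand-ins, not the commit's code):

    // Sketch: cache one compiled hook per argument-type signature.
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    enum arg_type { INT32_T, FLOAT_T, BUFFER_T };      // stand-in for rt::arg_type
    using signature_t = std::vector<arg_type>;

    std::string compile_hook(const signature_t& sig) { // stub for the expensive
      return "hook/" + std::to_string(sig.size());     // PyTorch extension build
    }

    const std::string& get_hook(const signature_t& sig) {
      static std::map<signature_t, std::string> cache; // signature -> built hook
      auto it = cache.find(sig);
      if(it == cache.end())                            // build on first use only
        it = cache.emplace(sig, compile_hook(sig)).first;
      return it->second;
    }

    int main() {
      signature_t sig = {BUFFER_T, BUFFER_T, INT32_T};
      std::cout << (&get_hook(sig) == &get_hook(sig)) << std::endl; // 1: cached
    }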

@@ -6,6 +6,7 @@
 #include <regex>
 #include <algorithm>
 #include "triton/runtime/function.h"
+#include "triton/runtime/arg.h"
 #include "triton/lang/code_gen.h"
 #include "triton/lang/parser.h"
 #include "triton/lang/cpp.h"
@@ -40,9 +41,10 @@ void delete_grid(size_t id) {
 /* Function map */
 void register_fn(size_t id,
-                 const std::string& src,
-                 const rt::function::options_space_t& opt) {
-  id_fn_map[id].reset(new rt::function(src, opt));
+                 const std::string& src,
+                 const rt::function::options_space_t& opt,
+                 const std::string &cache_ref) {
+  id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
 }
 void delete_fn(size_t id) {
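register_fn now threads a cache_ref string through to rt::function; given that the hash is now computed inside libtriton, this string plausibly serves as a stable reference into that cache. A hypothetical call site (the source and the cache key are made up for illustration):

    // Hypothetical usage of the new register_fn(); "add_fp32" is an assumed
    // cache_ref value, not taken from the commit.
    size_t id = make_op_id();
    rt::function::options_space_t opt;
    std::string src = "...";                 // Triton-C source of the kernel
    register_fn(id, src, opt, "add_fp32" /* cache_ref */);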
@@ -64,6 +66,7 @@ size_t make_op_id() {
   return id_fn_map.size();
 }
+
 /* TF scalar wrapper */
 size_t make_scalar_id() {
   size_t ret = i64scalar_map.size();
@@ -423,6 +426,37 @@ inline std::string to_torch_ty(ir::type *ty) {
   throw std::runtime_error("unknown type");
 }
+
+inline std::string to_torch_ty(rt::arg_type ty){
+  switch(ty){
+    case rt::INT1_T: return "int64_t";
+    case rt::INT8_T: return "int64_t";
+    case rt::INT16_T: return "int64_t";
+    case rt::INT32_T: return "int64_t";
+    case rt::INT64_T: return "int64_t";
+    case rt::HALF_T: return "double";
+    case rt::FLOAT_T: return "double";
+    case rt::DOUBLE_T: return "double";
+    case rt::BUFFER_T: return "torch::Tensor";
+    default: return "UNKNOWN";
+  }
+}
+
+inline std::string to_c_ty(rt::arg_type ty){
+  switch(ty){
+    case rt::INT1_T: return "bool";
+    case rt::INT8_T: return "int8_t";
+    case rt::INT16_T: return "int16_t";
+    case rt::INT32_T: return "int32_t";
+    case rt::INT64_T: return "int64_t";
+    case rt::HALF_T: return "half";
+    case rt::FLOAT_T: return "float";
+    case rt::DOUBLE_T: return "double";
+    case rt::BUFFER_T: return "drv::cu_buffer";
+    default: return "UNKNOWN";
+  }
+}
+
 inline std::string to_c_ty(ir::type *ty) {
   if(ty->is_integer_ty(1))
     return "bool";
@@ -448,33 +482,30 @@ inline std::string to_c_ty(ir::type *ty) {
 void gen_torch_signature(std::ostringstream& oss,
-                         ir::function* fn,
-                         const std::string& name) {
-  const auto& args = fn->args();
+                         const std::string& name,
+                         const std::vector<rt::arg_type>& args) {
   std::string ret_ty = "void";
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
   oss << "int64_t bench_id, ";
   for(size_t i = 0; i < args.size(); i++) {
-    ir::argument* arg = args[i];
     if(i > 0)
       oss << ", ";
-    oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
+    oss << to_torch_ty(args[i]) << " " << "th_arg_" << i;
   }
   oss << ")";
 }
 
 void gen_torch_init_driver(std::ostringstream &oss,
-                           const std::vector<ir::argument*>&args) {
-  ir::argument* tensor = nullptr;
-  for(ir::argument* arg: args)
-    if(arg->get_type()->is_pointer_ty()){
-      tensor = arg;
+                           const std::vector<rt::arg_type>&args) {
+  // Find index of first buffer
+  size_t i;
+  for(i = 0; i < args.size(); i++)
+    if(args[i] == rt::BUFFER_T)
       break;
-    }
   oss << "  // Wrap CUDA handles" << std::endl;
-  oss << "  c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl;
+  oss << "  c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl;
   oss << "  // Get stream" << std::endl;
   oss << "  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
   oss << "  triton::driver::cu_stream stream(custream, false);" << std::endl;
@@ -482,28 +513,28 @@ void gen_torch_init_driver(std::ostringstream &oss,
 }
 
 void gen_torch_make_handles(std::ostream &os,
-                            const std::vector<ir::argument*>& args) {
+                            const std::vector<rt::arg_type>& args) {
   for(unsigned i = 0; i < args.size(); i++){
-    ir::argument *arg = args[i];
-    const std::string& name = arg->get_name();
-    ir::type* ty = arg->get_type();
-    if(!ty->is_pointer_ty())
-      os << "  " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
+    rt::arg_type arg = args[i];
+    const std::string th_name = "th_arg_" + std::to_string(i);
+    const std::string name = "arg_" + std::to_string(i);
+    if(arg != rt::BUFFER_T)
+      os << "  " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl;
     else{
-      os << "  CHECK_INPUT(" << name << ");" << std::endl;
-      os << "  drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
-            " (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
+      os << "  CHECK_INPUT(" << th_name << ");" << std::endl;
+      os << "  drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), "
+            " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl;
     }
   }
 }
 
-void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
+void gen_torch_make_launch_function(std::ostream &os,
+                                    const std::vector<rt::arg_type>& args) {
   os << "  std::function<void()> run = [&](){\n  ";
   os << "    (*id_fn_map.at(id))({";
   for(unsigned i = 0; i < args.size() ; i++){
-    ir::argument *arg = args[i];
-    std::string name = "arg_" + arg->get_name();
-    if(arg->get_type()->is_pointer_ty())
+    std::string name = "arg_" + std::to_string(i);
+    if(args[i] == rt::BUFFER_T)
       name = "&" + name;
     if(i > 0)
       os << ", ";
@@ -531,15 +562,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
 }
 
 std::tuple<std::string,
-           std::string> make_torch_src(const std::string& src,
-                                       const runtime::function::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  std::string name = fn->get_name();
+           std::string> make_torch_src(const std::string& name, std::vector<rt::arg_type> args) {
   // generate framework code
   std::ostringstream oss;
   oss << R"(
@@ -563,11 +586,11 @@ extern std::map<size_t, int64_t> i64scalar_map;
 )";
-  gen_torch_signature(oss, fn, name);
+  gen_torch_signature(oss, name, args);
   oss << " {" << std::endl;
-  gen_torch_init_driver(oss, fn->args());
-  gen_torch_make_handles(oss, fn->args());
-  gen_torch_make_launch_function(oss, fn->args());
+  gen_torch_init_driver(oss, args);
+  gen_torch_make_handles(oss, args);
+  gen_torch_make_launch_function(oss, args);
   //gen_torch_ret(oss);
   oss << "}" << std::endl;
@@ -578,6 +601,22 @@ extern std::map<size_t, int64_t> i64scalar_map;
   return {oss.str(), name};
 }
 
+/* Function signature */
+std::vector<rt::arg_type> get_fn_signature(const std::string& src,
+                                           const runtime::function::options_space_t& opt) {
+  // triton-ir code-gen
+  ir::context ctx;
+  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
+  make_module(src, &*ir, opt);
+  // function
+  ir::function* fn = ir->get_function_list().front();
+  // extract signature
+  std::vector<rt::arg_type> ret;
+  ir::function_type* ty = fn->get_fn_type();
+  for(size_t i = 0; i < ty->get_num_params(); i++)
+    ret.push_back(rt::convert(ty->get_param_ty(i)));
+  return ret;
+}
 
 typedef triton::runtime::function::options_t options_t;
 typedef triton::runtime::function::options_space_t options_space_t;
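Taken together, callers now parse the Triton-C source once to recover the signature, then generate the hook from the signature alone. A sketch of the flow (variable names illustrative; src and opt assumed given):

    std::vector<rt::arg_type> args = get_fn_signature(src, opt); // parse once
    auto [code, op_name] = make_torch_src("kernel", args);       // name + signature only
    // Kernels sharing a name and signature yield identical `code`, so one
    // compiled PyTorch extension serves them all.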
@@ -593,6 +632,17 @@ PYBIND11_MODULE(libtriton, m) {
         "Creates C++ source code for a custom PyTorch op ");
 
   // bindings for triton classes
+  pybind11::enum_<rt::arg_type>(m, "arg_type")
+      .value("int1", rt::INT1_T)
+      .value("int8", rt::INT8_T)
+      .value("int16", rt::INT16_T)
+      .value("int32", rt::INT32_T)
+      .value("int64", rt::INT64_T)
+      .value("half", rt::HALF_T)
+      .value("float", rt::FLOAT_T)
+      .value("double", rt::DOUBLE_T)
+      .value("buffer", rt::BUFFER_T);
+
   pybind11::class_<options_t>(m, "options")
       .def(pybind11::init<>())
       .def("d", &options_t::D<int>)
@@ -604,6 +654,7 @@ PYBIND11_MODULE(libtriton, m) {
       .def_readwrite("num_warps", &options_space_t::num_warps);
 
   // hooks into triton constructs since frameworks may not use pybind11
+  m.def("get_fn_signature", &get_fn_signature);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);