[python] refactoring in anticipation of pytorch support
This commit is contained in:
@@ -136,7 +136,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << "}, *id_grid_map.at(id_), stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
const std::string &opname,
|
||||
const std::vector<ir::argument*>& args){
|
||||
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
|
||||
@@ -151,7 +151,7 @@ void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
os << ", " + opname << ");\n";
|
||||
}
|
||||
|
||||
void gen_register_op(std::ostream &os, const std::string &name,
|
||||
void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
const std::vector<ir::argument*>& args,
|
||||
const std::vector<std::string>& outputs){
|
||||
os << "REGISTER_OP(\"" << name << "\")\n";
|
||||
@@ -195,15 +195,12 @@ extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_tensorflow_src(std::string src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt)
|
||||
{
|
||||
src = preheader() + src;
|
||||
void make_module(const std::string& src, ir::module* ir,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
std::string copy = preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src, true);
|
||||
Preprocessor cpp(&copy, true);
|
||||
for(auto it: opt.defines){
|
||||
cpp.AddMacro(it.first, &it.second[0]);
|
||||
}
|
||||
@@ -211,11 +208,19 @@ std::tuple<std::string,
|
||||
// parse
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
Generator gen(&parser);
|
||||
gen.Gen(ir);
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_tensorflow_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt)
|
||||
{
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
Generator gen(&parser);
|
||||
gen.Gen(&*ir);
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
@@ -287,16 +292,145 @@ private:
|
||||
|
||||
// register kernel builder
|
||||
)";
|
||||
gen_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
oss << R"(
|
||||
// register op
|
||||
)";
|
||||
gen_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
return {oss.str(), name};
|
||||
}
|
||||
|
||||
|
||||
// Maps a Triton IR type to the type name written into the generated
// PyTorch op signature.
//
// Integers are named after their bit width (1 -> "bool", 8/16/32/64 ->
// "intN"), floating point types after their precision, and any pointer
// type becomes "Tensor".
//
// Throws std::runtime_error("unknown type") for every other type.
inline std::string to_torch_ty(ir::type *ty) {
  // Bit widths are probed in ascending order, exactly one can match.
  struct int_name_t { int bits; const char* name; };
  const int_name_t int_names[] = {
    { 1, "bool" },
    { 8, "int8" },
    {16, "int16"},
    {32, "int32"},
    {64, "int64"}
  };
  for(const int_name_t& entry: int_names) {
    if(ty->is_integer_ty(entry.bits))
      return entry.name;
  }

  // floating point types
  if(ty->is_half_ty())
    return "float16";
  if(ty->is_float_ty())
    return "float32";
  if(ty->is_double_ty())
    return "float64";

  // buffers are exposed as tensors
  if(ty->is_pointer_ty())
    return "Tensor";

  throw std::runtime_error("unknown type");
}
|
||||
|
||||
|
||||
|
||||
void gen_torch_signature(std::ostringstream& oss,
|
||||
ir::function* fn,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::string& name) {
|
||||
const auto& args = fn->args();
|
||||
std::vector<ir::type*> out_types;
|
||||
for(const std::string& out: outputs) {
|
||||
auto it = std::find_if(args.begin(), args.end(),
|
||||
[&](ir::argument* arg) { return arg->get_name() == out; });
|
||||
if(it == args.end())
|
||||
throw std::runtime_error("unknown argument");
|
||||
out_types.push_back((*it)->get_type());
|
||||
}
|
||||
|
||||
oss << "std::tuple<";
|
||||
for(size_t i = 0; i < out_types.size(); i++){
|
||||
if(i > 0)
|
||||
oss << ", ";
|
||||
oss << to_torch_ty(out_types[i]);
|
||||
}
|
||||
oss << "> ";
|
||||
oss << name << "(";
|
||||
oss << "int64 id" << std::endl;
|
||||
for(size_t i = 0; i < args.size(); i++) {
|
||||
ir::argument* arg = args[i];
|
||||
if(i > 0)
|
||||
oss << ", ";
|
||||
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
|
||||
}
|
||||
oss << ")";
|
||||
}
|
||||
|
||||
// Writes the prologue of the generated PyTorch op to `oss`. The emitted
// statements declare, in this order:
//   device   - device index read from the storage of the tensor `tensor_name`
//   custream - current CUDA stream of that device
//   stream   - triton::driver::cu_stream constructed from (custream, false)
//   ctx      - the context returned by stream.context()
//
// oss:         stream receiving the generated source
// tensor_name: name of the tensor variable, in the generated code, whose
//              device is queried. Defaults to "torcha", the name that was
//              previously hard-coded.
//              NOTE(review): nothing visible here declares a variable called
//              `torcha` in the generated op; callers should pass the name of
//              an actual tensor argument -- confirm.
void gen_torch_init_driver(std::ostringstream &oss,
                           const std::string &tensor_name = "torcha") {
  oss << " // Wrap CUDA handles" << '\n';
  oss << " c10::DeviceIndex device = " << tensor_name << ".storage().device().index();" << '\n';
  oss << " // Get stream" << '\n';
  oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << '\n';
  oss << " triton::driver::cu_stream stream(custream, false);" << '\n';
  oss << " triton::driver::context* ctx = stream.context();" << '\n';
}
|
||||
|
||||
void gen_torch_make_handles(std::ostream &os,
|
||||
const std::vector<ir::argument*>& args) {
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
ir::argument *arg = args[i];
|
||||
if(!arg->get_type()->is_pointer_ty())
|
||||
continue;
|
||||
const std::string& name = arg->get_name();
|
||||
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage.data(), false);\n ";
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = arg->get_name();
|
||||
if(arg->get_type()->is_pointer_ty())
|
||||
name = "&cu_" + name;
|
||||
if(i > 0)
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, *id_grid_map.at(id), stream); \n";
|
||||
}
|
||||
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_pytorch_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
// generate framework code
|
||||
std::ostringstream oss;
|
||||
oss << R"(
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
)";
|
||||
|
||||
gen_torch_signature(oss, fn, outputs, name);
|
||||
oss << " {" << std::endl;
|
||||
gen_torch_init_driver(oss);
|
||||
gen_torch_make_handles(oss, fn->args());
|
||||
gen_torch_make_launch_function(oss, fn->args());
|
||||
oss << std::endl << "}";
|
||||
|
||||
oss << "static auto registry = torch::jit::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
typedef triton::runtime::function::options_t options_t;
|
||||
typedef triton::runtime::function::options_space_t options_space_t;
|
||||
|
||||
@@ -307,6 +441,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
m.def("make_tensorflow_src", &make_tensorflow_src,
|
||||
"Creates C++ source code for a custom Tensorflow op "
|
||||
"corresponding to the specified Triton kernel");
|
||||
m.def("make_pytorch_src", &make_pytorch_src,
|
||||
"Creates C++ source code for a custom PyTorch op ");
|
||||
|
||||
// bindings for triton classes
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
|
Reference in New Issue
Block a user