[PYTHON][TENSORFLOW] More work

This commit is contained in:
Philippe Tillet
2019-10-30 18:39:58 -04:00
parent fd09f9c99d
commit e0fe8d9058
6 changed files with 184 additions and 104 deletions

View File

@@ -112,51 +112,70 @@ inline std::string ref_to_tf_ty(ir::type *ty) {
return res;
}
std::string tf_normalize(const std::string& name) {
std::string ret = name;
auto tolower = [](char c) { return std::tolower(c);};
std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
return ret;
}
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
struct tf_alloc_t{
enum type_t{
OUTPUT,
TEMP
};
tf_alloc_t(const std::string& _name, type_t _type)
: name(_name), type(_type), tf_name(tf_normalize(_name)){ }
std::string tf_name;
std::string name;
type_t type;
size_t shape_id;
};
typedef std::vector<tf_alloc_t> alloc_map_t;
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
for(unsigned i = 0; i < args.size(); i++){
ir::value *arg = args[i];
const std::string& name = arg->get_name();
std::string ty = to_tf_ty(arg->get_type());
if(!arg->get_type()->is_pointer_ty())
os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
else if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
else if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return x.name == name;
}) == allocs.end())
os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
else
os << " Tensor* " << name << " = nullptr;\n ";
}
}
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
for(unsigned i = 0; i < outputs.size(); i++)
os << " TensorShape shape" << i << ";\n ";
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
// initialize shapes
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
out_idx.push_back(idx);
}
for(unsigned i = 0; i < outputs.size(); i++)
os << " const Tensor& " << outputs[i] << "_shape = context->input(" << out_idx[i] << ");\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " const int32* " << outputs[i] << "_shape_data = (const int32*)" << outputs[i] << "_shape.tensor_data().data();\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " size_t " << outputs[i] << "_rank = " << outputs[i] << "_shape.dim_size(0);\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " for(size_t d = 0; d < " << outputs[i] << "_rank ; d++) "
<< "shape" << i << ".AddDim(" << outputs[i] << "_shape_data[d]);\n ";
for(const auto& x: allocs)
os << " TensorShape " << x.name << "_shape;\n ";
for(const auto& x: allocs)
os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
for(const auto& x: allocs)
os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
for(const auto& x: allocs)
os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
for(const auto& x: allocs)
os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
<< x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";
// allocate
for(unsigned i = 0; i < outputs.size(); i++)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << i << ", shape" << i << ", &" << outputs[i] << "));\n ";
int output = 0;
for(const auto& x: allocs){
if(x.type == tf_alloc_t::OUTPUT)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
else
os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
}
}
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
@@ -169,7 +188,7 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
}
void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
@@ -181,9 +200,9 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id_), stream);\n";
os << "}, *id_grid_map.at(id_), stream);\n ";
os << " };\n ";
os << " run();";
os << " run();\n ";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
@@ -191,69 +210,53 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
const alloc_map_t& allocs){
auto tolower = [](char c) { return std::tolower(c);};
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
std::transform(name.begin(), name.end(), name.begin(), tolower);
std::string name = tf_normalize(arg->get_name());
if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")";
}
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << ".HostMemory(\"" << name << "_shape\")";
}
for(const auto& x: allocs)
os << ".HostMemory(\"" << x.tf_name << "_shape\")";
os << ", " + opname << ");\n";
}
void gen_tf_register_op(std::ostream &os, const std::string &name,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
const alloc_map_t& allocs){
auto tolower = [](char c) { return std::tolower(c);};
os << "REGISTER_OP(\"" << name << "\")\n";
for(size_t i = 0; i < args.size(); i++)
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
std::transform(name.begin(), name.end(), name.begin(), tolower);
if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
std::string name = tf_normalize(arg->get_name());
if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return name == x.tf_name;
}) == allocs.end())
os << " .Input(\"" << name << ": T" << i << "\")\n";
else
os << " .Input(\"" << name << "_shape: int32\")\n";
}
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
out_idx.push_back(idx);
}
for(size_t i = 0; i < out_idx.size(); i++){
std::string name = outputs[i];
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << " .Output(\"" << name << ": T" << out_idx[i] << "\")\n";
}
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT)
os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " shape_inference::ShapeHandle handle" << i << ";\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " ctx->MakeShapeFromShapeTensor(" << out_idx[i] << ", &handle" << i << ");\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " ctx->set_output(" << i << ", handle" << i << ");\n";
size_t current = 0;
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT){
os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
}
os << " return Status::OK();\n";
os << " })\n";
@@ -280,6 +283,7 @@ void make_module(const std::string& src, ir::module* ir,
std::tuple<std::string,
std::string> make_tensorflow_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt)
{
// triton-ir code-gen
@@ -289,10 +293,28 @@ std::tuple<std::string,
// function
ir::function* fn = ir->get_function_list().front();
const std::vector<ir::argument*>& args = fn->args();
std::string name = fn->get_name();
std::string cc_name = name;
cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
std::string opname = cc_name + "Op";
// allocation info
alloc_map_t allocs;
for(size_t i = 0; i < outputs.size(); i++)
allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
for(size_t i = 0; i < tmp.size(); i++)
allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));
for(auto &x: allocs){
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == x.name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
x.shape_id = idx;
}
std::ostringstream oss;
oss << R"(
@@ -323,6 +345,11 @@ class )" << opname << R"(: public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
)";
for(const auto& alloc: allocs)
oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";
oss << R"(
}
void Compute(OpKernelContext* context){
@@ -335,21 +362,21 @@ class )" << opname << R"(: public OpKernel {
// extract inputs
)";
gen_extract_inputs(oss, fn->args(), outputs);
gen_extract_inputs(oss, args, allocs);
oss << R"(
// set outputs
)";
gen_set_outputs(oss, fn->args(), outputs);
gen_set_outputs(oss, args, allocs);
oss << R"(
// wrap tensors
)";
gen_make_handles(oss, fn->args());
gen_make_handles(oss, args);
oss << R"(
)";
oss << R"(
// launch function
)";
gen_make_launch_function(oss, outputs.size(), fn->args());
gen_make_launch_function(oss, args);
oss << R"(
}
@@ -357,15 +384,20 @@ private:
int id_;
int bench_;
int64 bench_id_;
)";
for(const auto& alloc: allocs)
oss << "DataType " << alloc.name << "_type;\n ";
oss << R"(
};
// register kernel builder
)";
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args(), outputs);
gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
oss << R"(
// register op
)";
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
gen_tf_register_op(oss, cc_name, args, allocs);
return {oss.str(), name};
}
@@ -517,6 +549,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
std::tuple<std::string,
std::string> make_torch_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;