[PYTHON][CORE] Deprecating Tensorflow support

This commit is contained in:
Philippe Tillet
2020-02-10 04:19:17 -05:00
committed by Philippe Tillet
parent d7a781dd40
commit 404dd18333
5 changed files with 26 additions and 108 deletions

View File

@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {
void gen_torch_signature(std::ostringstream& oss,
ir::function* fn,
const std::vector<std::string>& outputs,
const std::string& name) {
const auto& args = fn->args();
std::vector<ir::type*> out_types;
for(const std::string& out: outputs) {
auto it = std::find_if(args.begin(), args.end(),
[&](ir::argument* arg) { return arg->get_name() == out; });
if(it == args.end())
throw std::runtime_error("unknown argument");
out_types.push_back((*it)->get_type());
}
std::string ret_ty;
if(out_types.empty())
ret_ty = "void";
else{
ir::type* ty = out_types[0];
ret_ty = to_torch_ty(ty);
if(out_types.size() > 1){
for(size_t i = 1; i < out_types.size(); i++)
if(out_types[i] != ty)
throw std::runtime_error("outputs of different types not supported by pytorch");
ret_ty = "std::vector<" + ret_ty + ">";
}
}
std::string ret_ty = "void";
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
@@ -555,9 +532,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
std::tuple<std::string,
std::string> make_torch_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt) {
const runtime::function::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;
)";
gen_torch_signature(oss, fn, outputs, name);
gen_torch_signature(oss, fn, name);
oss << " {" << std::endl;
gen_torch_init_driver(oss, fn->args());
gen_torch_make_handles(oss, fn->args());
gen_torch_make_launch_function(oss, fn->args());
gen_torch_ret(oss, outputs);
//gen_torch_ret(oss);
oss << "}" << std::endl;
oss << std::endl;