[PYTHON][CORE] Deprecating Tensorflow support
This commit is contained in:
committed by
Philippe Tillet
parent
d7a781dd40
commit
404dd18333
@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {
|
||||
|
||||
void gen_torch_signature(std::ostringstream& oss,
|
||||
ir::function* fn,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::string& name) {
|
||||
const auto& args = fn->args();
|
||||
std::vector<ir::type*> out_types;
|
||||
for(const std::string& out: outputs) {
|
||||
auto it = std::find_if(args.begin(), args.end(),
|
||||
[&](ir::argument* arg) { return arg->get_name() == out; });
|
||||
if(it == args.end())
|
||||
throw std::runtime_error("unknown argument");
|
||||
out_types.push_back((*it)->get_type());
|
||||
}
|
||||
|
||||
std::string ret_ty;
|
||||
if(out_types.empty())
|
||||
ret_ty = "void";
|
||||
else{
|
||||
ir::type* ty = out_types[0];
|
||||
ret_ty = to_torch_ty(ty);
|
||||
if(out_types.size() > 1){
|
||||
for(size_t i = 1; i < out_types.size(); i++)
|
||||
if(out_types[i] != ty)
|
||||
throw std::runtime_error("outputs of different types not supported by pytorch");
|
||||
ret_ty = "std::vector<" + ret_ty + ">";
|
||||
}
|
||||
}
|
||||
|
||||
std::string ret_ty = "void";
|
||||
oss << ret_ty << " " << name << "(";
|
||||
oss << "int64_t id, ";
|
||||
oss << "int64_t bench, ";
|
||||
@@ -555,9 +532,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_torch_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::vector<std::string>& tmp,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
)";
|
||||
|
||||
gen_torch_signature(oss, fn, outputs, name);
|
||||
gen_torch_signature(oss, fn, name);
|
||||
oss << " {" << std::endl;
|
||||
gen_torch_init_driver(oss, fn->args());
|
||||
gen_torch_make_handles(oss, fn->args());
|
||||
gen_torch_make_launch_function(oss, fn->args());
|
||||
gen_torch_ret(oss, outputs);
|
||||
//gen_torch_ret(oss);
|
||||
oss << "}" << std::endl;
|
||||
|
||||
oss << std::endl;
|
||||
|
Reference in New Issue
Block a user