some work on conv

This commit is contained in:
Philippe Tillet
2019-10-31 18:08:27 -04:00
parent 91a2fd463b
commit 739a8d9061
10 changed files with 278 additions and 24 deletions

View File

@@ -514,6 +514,7 @@ void gen_torch_make_handles(std::ostream &os,
}
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::cout << 9 << std::endl;";
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id))({";
for(unsigned i = 0; i < args.size() ; i++){
@@ -528,6 +529,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << "}, *id_grid_map.at(id), &stream);\n";
os << " };\n ";
os << " run();";
os << " std::cout << 10 << std::endl;";
os << " if(bench > 0)\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
@@ -586,10 +588,14 @@ extern std::map<size_t, int64_t> i64scalar_map;
gen_torch_signature(oss, fn, outputs, name);
oss << " {" << std::endl;
oss << " std::cout << 1 << std::endl;";
gen_torch_init_driver(oss, fn->args());
gen_torch_make_handles(oss, fn->args());
oss << " std::cout << 2 << std::endl;";
gen_torch_make_launch_function(oss, fn->args());
oss << " std::cout << 3 << std::endl;";
gen_torch_ret(oss, outputs);
oss << " std::cout << \"done\" << std::endl;\n";
oss << "}" << std::endl;
oss << std::endl;