[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for TensorFlow Graph Execution
This commit is contained in:
Philippe Tillet
2019-10-28 17:12:37 -04:00
parent e9c787ef05
commit 448f4433d9
9 changed files with 82 additions and 52 deletions

View File

@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << " };\n ";
os << " run();";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,6 +186,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
os << " .Input(\"" << name << ": T" << i << "\")\n";
}
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
out_idx.push_back(idx);
}
for(size_t i = 0; i < out_idx.size(); i++)
os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
os << " return Status::OK();\n";
os << " })\n";
os << ";\n";
}
@@ -313,7 +322,7 @@ oss << R"(
private:
int id_;
int bench_;
int bench_id_;
int64 bench_id_;
};
// register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << " };\n ";
os << " run();";
os << " if(bench > 0)\n ";
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {