[PYTHON][KERNEL] Added benchmarking functionalities for kernels
This commit is contained in:
@@ -20,13 +20,13 @@ using namespace triton;
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
|
||||
/* TF triton op properties */
|
||||
|
||||
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
std::map<size_t, double> fp64scalar_map;
|
||||
std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
/* Grid map */
|
||||
|
||||
void register_grid(size_t id,
|
||||
const rt::function::grid_fn_ty& grid_fn) {
|
||||
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
|
||||
@@ -36,6 +36,8 @@ void delete_grid(size_t id) {
|
||||
id_grid_map.erase(id);
|
||||
}
|
||||
|
||||
/* Function map */
|
||||
|
||||
void register_fn(size_t id,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
@@ -56,8 +58,11 @@ size_t make_op_id() {
|
||||
return id_fn_map.size();
|
||||
}
|
||||
|
||||
/* TF scalar wrapper */
|
||||
/* Reserve a new entry in i64scalar_map and return its id.
 *
 * The id is the size of the table before the insertion. The entry is
 * value-initialised (zero) right away so the table actually records the
 * reservation; returning the size without inserting anything would hand
 * the same id to every caller until something else wrote to the table.
 *
 * NOTE(review): ids are derived from the table size, so this presumably
 * relies on entries never being erased and on keys being dense -- confirm
 * against the other writers of i64scalar_map.
 */
size_t make_scalar_id() {
  const size_t ret = i64scalar_map.size();
  i64scalar_map[ret] = int64_t();
  return ret;
}
|
||||
|
||||
bool has_scalar(size_t id) {
|
||||
@@ -135,8 +140,9 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
|
||||
}
|
||||
}
|
||||
|
||||
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " (*id_fn_map.at(id_))({";
|
||||
void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
|
||||
os << " std::function<void()> run = [&](){\n ";
|
||||
os << " (*id_fn_map.at(id_))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = arg->get_name();
|
||||
@@ -146,7 +152,11 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, *id_grid_map.at(id_), stream); \n";
|
||||
os << "}, *id_grid_map.at(id_), stream);\n";
|
||||
os << " };\n ";
|
||||
os << " run();";
|
||||
os << " if(bench_ > 0)\n ";
|
||||
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
|
||||
}
|
||||
|
||||
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
@@ -186,7 +196,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
throw std::runtime_error("unknown output");
|
||||
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
|
||||
}
|
||||
os << " .Attr(\"id: int\")" << std::endl;
|
||||
os << " .Attr(\"id: int\")\n";
|
||||
os << " .Attr(\"bench: int\")\n";
|
||||
os << " .Attr(\"bench_id: int\")\n";
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
@@ -247,6 +259,7 @@ std::tuple<std::string,
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@@ -260,13 +273,15 @@ namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
extern std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
class )" << opname << R"(: public OpKernel {
|
||||
public:
|
||||
explicit )" << opname << R"((OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context){
|
||||
@@ -291,12 +306,14 @@ oss << R"(
|
||||
oss << R"(
|
||||
// launch function
|
||||
)";
|
||||
gen_make_launch_function(oss, fn->args());
|
||||
gen_make_launch_function(oss, outputs.size(), fn->args());
|
||||
oss << R"(
|
||||
}
|
||||
|
||||
private:
|
||||
int id_;
|
||||
int bench_;
|
||||
int bench_id_;
|
||||
};
|
||||
|
||||
// register kernel builder
|
||||
@@ -379,6 +396,7 @@ void gen_torch_signature(std::ostringstream& oss,
|
||||
|
||||
oss << ret_ty << " " << name << "(";
|
||||
oss << "int64_t id, ";
|
||||
oss << "int64_t bench, ";
|
||||
for(size_t i = 0; i < args.size(); i++) {
|
||||
ir::argument* arg = args[i];
|
||||
if(i > 0)
|
||||
@@ -420,7 +438,8 @@ void gen_torch_make_handles(std::ostream &os,
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
os << " std::function<void()> run = [&](){\n ";
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = "arg_" + arg->get_name();
|
||||
@@ -431,7 +450,11 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
|
||||
os << name;
|
||||
}
|
||||
os << "}, *id_grid_map.at(id), &stream);\n";
|
||||
}
|
||||
os << " };\n ";
|
||||
os << " run();";
|
||||
os << " if(bench > 0)\n ";
|
||||
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
|
||||
}
|
||||
|
||||
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
if(outputs.size() == 1){
|
||||
@@ -465,6 +488,7 @@ std::tuple<std::string,
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "torch/extension.h"
|
||||
#include "torch/script.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
@@ -479,6 +503,7 @@ namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
extern std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
)";
|
||||
|
||||
|
Reference in New Issue
Block a user