[PYTHON] Added option to show PTX source code in Python
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <string>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -19,6 +20,8 @@ typedef std::pair<int, int> map_key_t;
|
||||
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
CUstream torch_get_cuda_stream(int64_t dev_id);
|
||||
|
||||
/* Grid utilities */
|
||||
|
||||
void register_grid(const map_key_t& key,
|
||||
@@ -34,15 +37,19 @@ void delete_grid(const map_key_t& key) {
|
||||
|
||||
void register_fn(const map_key_t& key,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt,
|
||||
const std::string &cache_ref) {
|
||||
id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
|
||||
const rt::function::options_space_t& opt) {
|
||||
if(id_fn_map.find(key) == id_fn_map.end())
|
||||
id_fn_map[key].reset(new rt::function(src, opt, ""));
|
||||
}
|
||||
|
||||
void delete_fn(const map_key_t& key) {
|
||||
id_fn_map.erase(key);
|
||||
}
|
||||
|
||||
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
|
||||
return id_fn_map[key]->ptx(&stream, opt);
|
||||
}
|
||||
|
||||
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
||||
pybind11::buffer_info info = data.request();
|
||||
@@ -113,7 +120,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
.def(pybind11::init<>())
|
||||
.def("d", &options_t::D<int>)
|
||||
.def_readonly("num_warps", &options_t::num_warps);
|
||||
.def_readwrite("num_warps", &options_t::num_warps)
|
||||
.def_readwrite("defines" , &options_t::defines);
|
||||
|
||||
pybind11::class_<options_space_t>(m, "options_space")
|
||||
.def(pybind11::init<>())
|
||||
@@ -122,6 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("get_fn_signature", &get_fn_signature);
|
||||
m.def("get_fn_ptx", &get_fn_ptx);
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("delete_grid", &delete_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
|
Reference in New Issue
Block a user