[PYTHON] Added option to show PTX source code in Python

This commit is contained in:
Philippe Tillet
2020-11-07 02:55:48 -05:00
committed by Philippe Tillet
parent cf80ccc798
commit 8f3ee53f24
10 changed files with 68 additions and 33 deletions

View File

@@ -121,8 +121,8 @@ dot = _dot.apply
torch.manual_seed(0)
M, N, K = 2048, 2048, 2048
-a = torch.rand((M, K)).cuda()
-b = torch.rand((K, N)).cuda()
+a = torch.rand((M, K)).cuda().half()
+b = torch.rand((K, N)).cuda().half()
#a[:] = 1
#b[:] = 1
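A quick way to sanity-check the new fp16 inputs (not part of the commit) is to compare the tutorial's dot against torch.matmul; this sketch assumes dot = _dot.apply from the hunk context takes (a, b) directly:

# Sketch only, not in the diff: compare the Triton result with cuBLAS.
c_tri = dot(a, b)
c_ref = torch.matmul(a, b)
# fp16 inputs mean small rounding differences, so check the max error rather than exact equality.
print((c_tri.float() - c_ref.float()).abs().max())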

View File

@@ -23,12 +23,9 @@ __global__ void add(float* z, float* x, float* y, int N) {
  @staticmethod
  def forward(ctx, x, y):
    z = torch.empty_like(x).cuda()
    N = x.numel()
    grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
    _add.kernel(z,x,y, N, grid=grid)
    return z
add = _add.apply
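For context, a minimal invocation of the wrapper above (not part of the commit), assuming a CUDA device is available:

# Sketch only, not in the diff: exercise the autograd wrapper defined above.
x = torch.rand(2**20, device='cuda')
y = torch.rand(2**20, device='cuda')
z = add(x, y)
print(torch.allclose(z, x + y))  # the Triton kernel should reproduce eager addition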

View File

@@ -3,6 +3,7 @@
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/runtime/arg.h"
#include "triton/lang/code_gen.h"
@@ -19,6 +20,8 @@ typedef std::pair<int, int> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
+CUstream torch_get_cuda_stream(int64_t dev_id);
/* Grid utilities */
void register_grid(const map_key_t& key,
@@ -34,15 +37,19 @@ void delete_grid(const map_key_t& key) {
void register_fn(const map_key_t& key,
const std::string& src,
-const rt::function::options_space_t& opt,
-const std::string &cache_ref) {
-id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
+const rt::function::options_space_t& opt) {
+if(id_fn_map.find(key) == id_fn_map.end())
+id_fn_map[key].reset(new rt::function(src, opt, ""));
}
void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
}
+std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
+triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
+return id_fn_map[key]->ptx(&stream, opt);
+}
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
@@ -113,7 +120,8 @@ PYBIND11_MODULE(libtriton, m) {
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
.def("d", &options_t::D<int>)
-.def_readonly("num_warps", &options_t::num_warps);
+.def_readwrite("num_warps", &options_t::num_warps)
+.def_readwrite("defines" , &options_t::defines);
pybind11::class_<options_space_t>(m, "options_space")
.def(pybind11::init<>())
@@ -122,6 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("get_fn_signature", &get_fn_signature);
m.def("get_fn_ptx", &get_fn_ptx);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);

View File

@@ -27,6 +27,10 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
return ret;
}
+CUstream torch_get_cuda_stream(int64_t dev_id) {
+return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
+}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
if(dev_id == -1){
if(!host_stream){
@@ -37,8 +41,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
}
else{
-CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
-triton::driver::cu_stream stream(custream, false);
+triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
}

View File

@@ -66,6 +66,26 @@ class kernel:
  def set_constant(self, device, name, value):
    libtriton.register_cst((self.op_id, device), name, value)
+ def ptx(self, device, **kwargs):
+   dev_id = device.index
+   libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
+   def _single_value_or_err(x, key):
+     if isinstance(x, list) and len(x) == 1:
+       return x[0]
+     if isinstance(x, list) and len(x) > 1:
+       if key in kwargs:
+         return kwargs[key]
+       raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
+                        'Please supply an explicit value as a keyword argument.')
+     return str(x)
+   defines = dict()
+   for (D, V) in self.opt.defines:
+     defines[D] = _single_value_or_err(V, D)
+   opt = libtriton.options()
+   opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
+   opt.defines = defines
+   return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
  def __call__(self, *args, **kwargs):
    for x in args:
      if isinstance(x, torch.Tensor):
@@ -73,9 +93,7 @@ class kernel:
        device = -1 if device is None else device
        break
    # lazily register function for device
-   if device not in self.registered:
-     self.registered.add(device)
-     libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
+   libtriton.register_fn((self.op_id, device), self.src, self.opt)
    # launch grid
    if 'grid' not in kwargs:
      raise RuntimeError('Must provide grid for kernel launch')
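Putting the pieces together, dumping PTX from user code could look roughly like the sketch below. It is not part of the commit: the triton.kernel constructor keywords are inferred from the defines/num_warps options shown in this diff, and src stands for a Triton-C string such as the vector-add kernel in the tutorial hunk above.

# Sketch only, not in the diff: illustrative use of the new kernel.ptx() method.
import torch
import triton

def show_ptx(src, device_index=0):
    # defines/num_warps given as candidate lists, as the options space allows
    kernel = triton.kernel(src, defines={'TILE': [128, 256, 512]}, num_warps=[4])
    # auto-tuned parameters must be pinned via keyword arguments, otherwise ptx() raises ValueError;
    # the value is passed as a string since defines end up in a string-to-string map
    return kernel.ptx(torch.device(f'cuda:{device_index}'), TILE='256')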