[PYTHON] Added option to show PTX source code in Python

Committed by: Philippe Tillet
Parent:  cf80ccc798
Commit:  8f3ee53f24

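The hunks below add a ptx() method to the C++ runtime function class (std::string function::ptx(driver::stream*, const options_t&)), expose it to Python through a new get_fn_ptx binding in libtriton, and surface it as kernel.ptx(device, **kwargs) on the Python kernel class. A minimal usage sketch, assuming a previously constructed triton.kernel instance named kern (its constructor is not part of this diff) running on the first CUDA device:

    import torch

    # `kern` is a hypothetical, already-constructed triton.kernel instance;
    # only the ptx() call itself is introduced by this commit.
    device = torch.device('cuda:0')
    print(kern.ptx(device))             # works when every tunable option has a single candidate
    print(kern.ptx(device, TM='128'))   # pin a define that was declared with several candidates
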
@@ -40,7 +40,7 @@ public:
   static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
   driver::context* context() const;
   void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
                            const std::string &proc, std::string layout,
                            llvm::SmallVectorImpl<char> &buffer,
                            const std::string &features,
                            file_type_t file_type);

@@ -122,7 +122,7 @@ private:
   triton::lang::translation_unit *make_ast(const std::string &src);
   std::unique_ptr<ir::module> make_ir(Parser &parser);
   std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
-  caller *make(driver::stream *stream, options_t opt);
+  void make(driver::stream *stream, options_t opt);
   void precompile(driver::stream *stream, const options_space_t& tuning_space);
   // autotune
   caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);

@@ -135,6 +135,7 @@ public:
   void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
   void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
   void set_cst(const std::string& name, void* data, size_t n_bytes);
+  std::string ptx(driver::stream *stream, const options_t& opt);

 private:
   std::map<std::string, std::vector<char>> cst_;

@@ -246,7 +246,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,


 // create Binary from options
-function::caller* function::make(driver::stream *stream, options_t opt) {
+void function::make(driver::stream *stream, options_t opt) {
+  if(callers_.find(opt) != callers_.end())
+    return;
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&src_, true);

@@ -267,8 +269,14 @@ function::caller* function::make(driver::stream *stream, options_t opt) {
   // }
   // create callable
   ir::function *tmp = ir->get_function_list()[0];
-  caller* ret = new caller(tmp, std::move(bin), opt);
-  return ret;
+  callers_[opt].reset(new caller(tmp, std::move(bin), opt));
+  auto& call = callers_[opt];
+  // copy constants
+  if(call)
+    for(const auto& cst: cst_){
+      std::unique_ptr<driver::buffer> buffer = call->parent()->symbol(cst.first.c_str());
+      stream->write(&*buffer, true, 0, cst.second);
+    }
 }

 // precompile all kernels spanned by given options space

@@ -288,16 +296,7 @@ void function::precompile(driver::stream* stream,
     for(auto D: space.defines)
       opt.defines[D.first] = D.second[params[i++]];
     // compile
-    caller* call = make(stream, opt);
-    if(!call)
-      return;
-    // copy constants
-    std::unique_ptr<driver::buffer> buffer;
-    for(const auto& cst: cst_){
-      buffer = call->parent()->symbol(cst.first.c_str());
-      stream->write(&*buffer, true, 0, cst.second);
-    }
-    callers_[opt].reset(call);
+    make(stream, opt);
   };
   // multi-threaded compilation
   _loop_nest(ranges, do_make);

@@ -305,6 +304,14 @@ void function::precompile(driver::stream* stream,
     throw std::runtime_error("could not compile kernel");
 }

+std::string function::ptx(driver::stream* stream, const options_t& opt) {
+  make(stream, opt);
+  const auto& fn = callers_.at(opt);
+  if(!fn)
+    return "";
+  return ((driver::cu_module*)fn->parent())->source();
+}
+
 // returns program with best compilation options for given parameter
 function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
                                      void** args, size_t args_size) {

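Note that function::ptx compiles on demand: it calls make(stream, opt), which is a no-op if a caller already exists for these options, and then reads the PTX back from the underlying CUDA module. When no caller could be produced it returns an empty string, which a caller may want to check for. A small Python-side sketch, reusing the hypothetical kern object from the sketch at the top:

    ptx = kern.ptx(torch.device('cuda:0'))
    if not ptx:  # empty string means no compiled caller exists for these options
        raise RuntimeError('no PTX available for the requested options')
    print(ptx)
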
@@ -121,8 +121,8 @@ dot = _dot.apply
 torch.manual_seed(0)

 M, N, K = 2048, 2048, 2048
-a = torch.rand((M, K)).cuda()
-b = torch.rand((K, N)).cuda()
+a = torch.rand((M, K)).cuda().half()
+b = torch.rand((K, N)).cuda().half()

 #a[:] = 1
 #b[:] = 1

@@ -23,12 +23,9 @@ __global__ void add(float* z, float* x, float* y, int N) {
     @staticmethod
     def forward(ctx, x, y):
         z = torch.empty_like(x).cuda()

         N = x.numel()
         grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)

         _add.kernel(z,x,y, N, grid=grid)

         return z

 add = _add.apply

@@ -3,6 +3,7 @@
 #include <pybind11/stl.h>
 #include <pybind11/functional.h>
 #include <string>
+#include "triton/driver/stream.h"
 #include "triton/runtime/function.h"
 #include "triton/runtime/arg.h"
 #include "triton/lang/code_gen.h"

@@ -19,6 +20,8 @@ typedef std::pair<int, int> map_key_t;
 std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;

+CUstream torch_get_cuda_stream(int64_t dev_id);
+
 /* Grid utilities */

 void register_grid(const map_key_t& key,

@@ -34,15 +37,19 @@ void delete_grid(const map_key_t& key) {

 void register_fn(const map_key_t& key,
                  const std::string& src,
-                 const rt::function::options_space_t& opt,
-                 const std::string &cache_ref) {
-  id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
+                 const rt::function::options_space_t& opt) {
+  if(id_fn_map.find(key) == id_fn_map.end())
+    id_fn_map[key].reset(new rt::function(src, opt, ""));
 }

 void delete_fn(const map_key_t& key) {
   id_fn_map.erase(key);
 }

+std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
+  triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
+  return id_fn_map[key]->ptx(&stream, opt);
+}
+
 void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
   pybind11::buffer_info info = data.request();

@@ -113,7 +120,8 @@ PYBIND11_MODULE(libtriton, m) {
   pybind11::class_<options_t>(m, "options")
       .def(pybind11::init<>())
       .def("d", &options_t::D<int>)
-      .def_readonly("num_warps", &options_t::num_warps);
+      .def_readwrite("num_warps", &options_t::num_warps)
+      .def_readwrite("defines"  , &options_t::defines);

   pybind11::class_<options_space_t>(m, "options_space")
       .def(pybind11::init<>())

@@ -122,6 +130,7 @@ PYBIND11_MODULE(libtriton, m) {

   // hooks into triton constructs since frameworks may not use pybind11
   m.def("get_fn_signature", &get_fn_signature);
+  m.def("get_fn_ptx", &get_fn_ptx);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);

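With num_warps and defines now writable on the options binding and get_fn_ptx exported, the PTX of a registered function can also be fetched without going through the Python kernel wrapper. A sketch under assumptions: (op_id, dev_id) is a hypothetical key matching an earlier register_fn call, TILE is the function's only define, and libtriton refers to the extension module exactly as kernel.py uses it further down:

    # hypothetical (op_id, dev_id) key, matching an earlier register_fn call
    opt = libtriton.options()
    opt.num_warps = 4
    opt.defines = {'TILE': '1024'}   # values are plain strings, as kernel.ptx() builds them
    ptx = libtriton.get_fn_ptx((op_id, dev_id), opt)
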
@@ -27,6 +27,10 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
   return ret;
 }

+CUstream torch_get_cuda_stream(int64_t dev_id) {
+  return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
+}
+
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
   if(dev_id == -1){
     if(!host_stream){

@@ -37,8 +41,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
     (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
   }
   else{
-    CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
-    triton::driver::cu_stream stream(custream, false);
+    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
     triton::driver::context* ctx = stream.context();
     (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
   }

@@ -66,6 +66,26 @@ class kernel:
     def set_constant(self, device, name, value):
         libtriton.register_cst((self.op_id, device), name, value)

+    def ptx(self, device, **kwargs):
+        dev_id = device.index
+        libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
+        def _single_value_or_err(x, key):
+            if isinstance(x, list) and len(x) == 1:
+                return x[0]
+            if isinstance(x, list) and len(x) > 1:
+                if key in kwargs:
+                    return kwargs[key]
+                raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
+                                  'Please supply an explicit value as a keyword argument.')
+            return str(x)
+        defines = dict()
+        for (D, V) in self.opt.defines:
+            defines[D] = _single_value_or_err(V, D)
+        opt = libtriton.options()
+        opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
+        opt.defines = defines
+        return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
+
     def __call__(self, *args, **kwargs):
         for x in args:
             if isinstance(x, torch.Tensor):

@@ -73,9 +93,7 @@ class kernel:
                 device = -1 if device is None else device
                 break
         # lazily register function for device
-        if device not in self.registered:
-            self.registered.add(device)
-            libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
+        libtriton.register_fn((self.op_id, device), self.src, self.opt)
         # launch grid
         if 'grid' not in kwargs:
             raise RuntimeError('Must provide grid for kernel launch')

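Tying this back to the add example earlier in the diff: once a kernel object exists, its PTX can be dumped right next to a launch. A sketch, assuming _add.kernel was built with a single-valued TILE define (its constructor is not shown in this diff):

    # hypothetical debugging hook inside _add.forward, after the launch above
    print(_add.kernel.ptx(x.device))  # registers the function for this device if needed, then returns its PTX
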
@@ -17,11 +17,11 @@ int main() {
 //    config_t{ord, x[0], x[1], 384, 384, 384},
 //    config_t{ord, x[0], x[1], 512, 512, 512},
 //    config_t{ord, x[0], x[1], 768, 768, 768},
-    config_t{ord, x[0], x[1], 1024, 1024, 1024},
+//    config_t{ord, x[0], x[1], 1024, 1024, 1024},
 //    config_t{ord, x[0], x[1], 1280, 1280, 1280},
 //    config_t{ord, x[0], x[1], 1536, 1536, 1536},
 //    config_t{ord, x[0], x[1], 2048, 2048, 2048},
-//    config_t{ord, x[0], x[1], 8192, 8192, 8192},
+    config_t{ord, x[0], x[1], 8192, 8192, 8192},

 //    config_t{ord, x[0], x[1], 256, 16, 256},
 //    config_t{ord, x[0], x[1], 512, 16, 512},

@@ -65,7 +65,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }

@@ -129,7 +129,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   if(mode == BENCH) {
     opt.defines.push_back({"TM", {"64", "128"}});
     opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8"}});
+    opt.defines.push_back({"TK", {"16"}});
     opt.defines.push_back({"TZ", {"1"}});
     opt.num_warps = {4};
   }