From 8f3ee53f24abc5b6512125b928d60d180d8a4475 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 7 Nov 2020 02:55:48 -0500
Subject: [PATCH] [PYTHON] Added option to show PTX source code in Python

---
 include/triton/driver/module.h       |  2 +-
 include/triton/runtime/function.h    |  3 ++-
 lib/runtime/function.cc              | 33 +++++++++++++++++-----------
 python/examples/tutorials/mat_mul.py |  4 ++--
 python/examples/tutorials/vec_add.py |  3 ---
 python/src/bindings.cc               | 17 ++++++++++----
 python/src/launch.cc                 |  7 ++++--
 python/triton/kernel.py              | 24 +++++++++++++++++---
 tests/bench/dot.cc                   |  6 ++---
 tests/common/dot.h                   |  2 +-
 10 files changed, 68 insertions(+), 33 deletions(-)

diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h
index 991af82ae..474b48cab 100755
--- a/include/triton/driver/module.h
+++ b/include/triton/driver/module.h
@@ -40,7 +40,7 @@ public:
   static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
   driver::context* context() const;
   void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
-                           const std::string &proc, std::string layout,
+                           const std::string &proc, std::string layout,
                            llvm::SmallVectorImpl<char> &buffer,
                            const std::string &features,
                            file_type_t file_type);
diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 01ebf0eb6..f0adafdd2 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -122,7 +122,7 @@ private:
   triton::lang::translation_unit *make_ast(const std::string &src);
   std::unique_ptr<ir::module> make_ir(Parser &parser);
   std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
-  caller *make(driver::stream *stream, options_t opt);
+  void make(driver::stream *stream, options_t opt);
   void precompile(driver::stream *stream, const options_space_t& tuning_space);
   // autotune
   caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
@@ -135,6 +135,7 @@ public:
   void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
   void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
   void set_cst(const std::string& name, void* data, size_t n_bytes);
+  std::string ptx(driver::stream *stream, const options_t& opt);

 private:
   std::map<std::string, std::vector<char>> cst_;
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index e4109b628..3fd8b5dca 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -246,7 +246,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,


 // create Binary from options
-function::caller* function::make(driver::stream *stream, options_t opt) {
+void function::make(driver::stream *stream, options_t opt) {
+  if(callers_.find(opt) != callers_.end())
+    return;
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&src_, true);
@@ -267,8 +269,14 @@ function::caller* function::make(driver::stream *stream, options_t opt) {
 //  }
   // create callable
   ir::function *tmp = ir->get_function_list()[0];
-  caller* ret = new caller(tmp, std::move(bin), opt);
-  return ret;
+  callers_[opt].reset(new caller(tmp, std::move(bin), opt));
+  auto& call = callers_[opt];
+  // copy constants
+  if(call)
+    for(const auto& cst: cst_){
+      std::unique_ptr<driver::buffer> buffer = call->parent()->symbol(cst.first.c_str());
+      stream->write(&*buffer, true, 0, cst.second);
+    }
 }

 // precompile all kernels spanned by given options space
@@ -288,16 +296,7 @@ void function::precompile(driver::stream* stream,
     for(auto D: space.defines)
       opt.defines[D.first] = D.second[params[i++]];
     // compile
-    caller* call = make(stream, opt);
-    if(!call)
-      return;
-    // copy constants
-    std::unique_ptr<driver::buffer> buffer;
-    for(const auto& cst: cst_){
-      buffer = call->parent()->symbol(cst.first.c_str());
-      stream->write(&*buffer, true, 0, cst.second);
-    }
-    callers_[opt].reset(call);
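+    // make() now compiles the kernel, caches it in callers_, and
+    // initializes constant memory, so a bare call suffices here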
+    make(stream, opt);
   };
   // multi-threaded compilation
   _loop_nest(ranges, do_make);
@@ -305,6 +304,14 @@ void function::precompile(driver::stream* stream,
     throw std::runtime_error("could not compile kernel");
 }

+std::string function::ptx(driver::stream* stream, const options_t& opt) {
+  make(stream, opt);
+  const auto& fn = callers_.at(opt);
+  if(!fn)
+    return "";
+  return ((driver::cu_module*)fn->parent())->source();
+}
+
 // returns program with best compilation options for given parameter
 function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
                                      void** args, size_t args_size) {
diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py
index 8ec788e45..6bd3a1495 100644
--- a/python/examples/tutorials/mat_mul.py
+++ b/python/examples/tutorials/mat_mul.py
@@ -121,8 +121,8 @@ dot = _dot.apply
 torch.manual_seed(0)

 M, N, K = 2048, 2048, 2048
-a = torch.rand((M, K)).cuda()
-b = torch.rand((K, N)).cuda()
+a = torch.rand((M, K)).cuda().half()
+b = torch.rand((K, N)).cuda().half()

 #a[:] = 1
 #b[:] = 1
diff --git a/python/examples/tutorials/vec_add.py b/python/examples/tutorials/vec_add.py
index cbf63ff08..acce14062 100644
--- a/python/examples/tutorials/vec_add.py
+++ b/python/examples/tutorials/vec_add.py
@@ -23,12 +23,9 @@ __global__ void add(float* z, float* x, float* y, int N) {
     @staticmethod
     def forward(ctx, x, y):
         z = torch.empty_like(x).cuda()
-
         N = x.numel()
         grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
-
         _add.kernel(z,x,y, N, grid=grid)
-
         return z

 add = _add.apply
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index ac1c68fc9..0fcae9d31 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include "triton/driver/stream.h"
 #include "triton/runtime/function.h"
 #include "triton/runtime/arg.h"
 #include "triton/lang/code_gen.h"
@@ -19,6 +20,8 @@ typedef std::pair map_key_t;
 std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
 std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;

+CUstream torch_get_cuda_stream(int64_t dev_id);
+
 /* Grid utilities */
 void register_grid(const map_key_t& key,
@@ -34,15 +37,19 @@ void delete_grid(const map_key_t& key) {
   id_grid_map.erase(key);
 }

 void register_fn(const map_key_t& key,
                  const std::string& src,
-                 const rt::function::options_space_t& opt,
-                 const std::string &cache_ref) {
-  id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
+                 const rt::function::options_space_t& opt) {
+  if(id_fn_map.find(key) == id_fn_map.end())
+    id_fn_map[key].reset(new rt::function(src, opt, ""));
 }

 void delete_fn(const map_key_t& key) {
   id_fn_map.erase(key);
 }

+std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
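+  // compile (or fetch from the cache) the kernel for these options on the
+  // CUDA stream torch is currently using for this device, then return its PTX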
+  triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
+  return id_fn_map[key]->ptx(&stream, opt);
+}

 void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
   pybind11::buffer_info info = data.request();
@@ -113,7 +120,8 @@ PYBIND11_MODULE(libtriton, m) {
   pybind11::class_<options_t>(m, "options")
       .def(pybind11::init<>())
       .def("d", &options_t::D)
-      .def_readonly("num_warps", &options_t::num_warps);
+      .def_readwrite("num_warps", &options_t::num_warps)
+      .def_readwrite("defines"  , &options_t::defines);

   pybind11::class_<options_space_t>(m, "options_space")
       .def(pybind11::init<>())
@@ -122,6 +130,7 @@

   // hooks into triton constructs since frameworks may not use pybind11
   m.def("get_fn_signature", &get_fn_signature);
+  m.def("get_fn_ptx", &get_fn_ptx);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 6bf8edd5b..995f9e1ea 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -27,6 +27,10 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
   return ret;
 }

+CUstream torch_get_cuda_stream(int64_t dev_id) {
+  return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
+}
+
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
   if(dev_id == -1){
     if(!host_stream){
@@ -37,8 +41,7 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
     (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
   }
   else{
-    CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
-    triton::driver::cu_stream stream(custream, false);
+    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
     triton::driver::context* ctx = stream.context();
     (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
   }
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index a406b61c3..e37f82340 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -66,6 +66,26 @@ class kernel:
     def set_constant(self, device, name, value):
         libtriton.register_cst((self.op_id, device), name, value)

+    def ptx(self, device, **kwargs):
+        dev_id = device.index
+        libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
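+        # compile-time options may be lists of auto-tuning candidates: a
+        # single candidate is used directly, while multiple candidates must
+        # be disambiguated through keyword arguments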
+        def _single_value_or_err(x, key):
+            if isinstance(x, list) and len(x) == 1:
+                return x[0]
+            if isinstance(x, list) and len(x) > 1:
+                if key in kwargs:
+                    return kwargs[key]
+                raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
+                                 'Please supply an explicit value as a keyword argument.')
+            return str(x)
+        defines = dict()
+        for (D, V) in self.opt.defines:
+            defines[D] = _single_value_or_err(V, D)
+        opt = libtriton.options()
+        opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
+        opt.defines = defines
+        return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
+
     def __call__(self, *args, **kwargs):
         for x in args:
             if isinstance(x, torch.Tensor):
@@ -73,9 +93,7 @@
                 device = -1 if device is None else device
                 break
         # lazily register function for device
-        if device not in self.registered:
-            self.registered.add(device)
-            libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
+        libtriton.register_fn((self.op_id, device), self.src, self.opt)
         # launch grid
         if 'grid' not in kwargs:
             raise RuntimeError('Must provide grid for kernel launch')
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 9f5a01d9c..7204e288b 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -17,11 +17,11 @@ int main() {
 //    config_t{ord, x[0], x[1], 384, 384, 384},
 //    config_t{ord, x[0], x[1], 512, 512, 512},
 //    config_t{ord, x[0], x[1], 768, 768, 768},
-      config_t{ord, x[0], x[1], 1024, 1024, 1024},
+//    config_t{ord, x[0], x[1], 1024, 1024, 1024},
 //    config_t{ord, x[0], x[1], 1280, 1280, 1280},
 //    config_t{ord, x[0], x[1], 1536, 1536, 1536},
 //    config_t{ord, x[0], x[1], 2048, 2048, 2048},
-//    config_t{ord, x[0], x[1], 8192, 8192, 8192},
+      config_t{ord, x[0], x[1], 8192, 8192, 8192},
 //    config_t{ord, x[0], x[1], 256, 16, 256},
 //    config_t{ord, x[0], x[1], 512, 16, 512},
@@ -65,7 +65,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index 35f24395c..d4bafa22b 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -129,7 +129,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   if(mode == BENCH) {
     opt.defines.push_back({"TM", {"64", "128"}});
     opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8"}});
+    opt.defines.push_back({"TK", {"16"}});
     opt.defines.push_back({"TZ", {"1"}});
     opt.num_warps = {4};
   }
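-- 

Usage sketch (illustrative only, not part of the patch): kernel.ptx compiles
the kernel for one fully-pinned set of compile-time options and returns the
PTX emitted by the backend. Assuming a compiled triton.kernel instance `kern`
whose TM/TN were declared as tuning ranges:

    import torch
    # parameters that were given as auto-tuning ranges must be pinned
    # explicitly via keyword arguments, otherwise kernel.ptx() raises
    # a ValueError
    print(kern.ptx(torch.device('cuda:0'), TM=64, TN=64, num_warps=4))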