From f81da73b6a912dfb130d9210a87d05515fcd3a3a Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 30 Jan 2021 18:09:49 -0500
Subject: [PATCH] [PYTHON] Added utility to read single Triton kernel from
 provided file in triton.read

---
NOTE(review): this patch was reflowed from a copy in which line breaks,
indentation and every `<...>` span had been lost. Blank context lines and
indentation are a best-effort reconstruction, and the bare `#include` /
`std::pair` context lines in python/src/bindings.cc still lack their original
`<...>` contents -- regenerate against the tree before applying.
`extract_kernels` was revised in review (unchecked `find()` results, shadowed
`pos`, unused locals), so the blob ids on the bindings.cc `index` line are stale.

 include/triton/lang/parser.h |  2 +-
 python/src/bindings.cc       | 61 +++++++++++++++++++++++++++++++++++-
 python/triton/kernel.py      |  7 ++++-
 python/triton/ops/softmax.c  |  2 +-
 python/triton/ops/softmax.py | 15 +++++----
 5 files changed, 75 insertions(+), 12 deletions(-)

diff --git a/include/triton/lang/parser.h b/include/triton/lang/parser.h
index 7908da90b..b9542b40e 100644
--- a/include/triton/lang/parser.h
+++ b/include/triton/lang/parser.h
@@ -227,7 +227,7 @@ public:
   FuncDef* CurFunc() { return curFunc_; }
   const TokenSequence& ts() const { return ts_; }
 
-private:
+protected:
   static bool IsBuiltin(FuncType* type);
   static bool IsBuiltin(const std::string& name);
   static Identifier* GetBuiltin(const Token* tok);
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 456860ce1..fbd41ef4c 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include <regex>
 #include "triton/driver/stream.h"
 #include "triton/runtime/function.h"
 #include "triton/runtime/arg.h"
@@ -15,6 +16,7 @@
 using namespace triton;
 namespace rt = triton::runtime;
 namespace drv = triton::driver;
+namespace lng = triton::lang;
 
 typedef std::pair map_key_t;
 
@@ -114,6 +116,63 @@ pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args
 }
 
 
+// Returns the definitions of the kernels listed in `names`, in the order in
+// which they appear in `str`, or `str` unchanged when `names` is empty. Throws
+// std::runtime_error when a requested kernel is missing or has no body.
+std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
+  if(names.empty())
+    return str;
+  // collect (name, offset in str) for every `__global__ void <name>` header
+  // (the {0,30} bound caps captured names at 31 characters)
+  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
+  std::sregex_iterator it(str.begin(), str.end(), regex);
+  std::sregex_iterator end;
+  std::vector<std::pair<std::string, size_t>> kernels;
+  for (; it != end; ++it)
+    kernels.emplace_back(it->str(1), static_cast<size_t>(it->position()));
+
+  // every requested kernel must be defined somewhere in str
+  for(const std::string& name: names) {
+    auto pred = [&name](const std::pair<std::string, size_t>& k) { return k.first == name; };
+    if(!std::any_of(kernels.begin(), kernels.end(), pred))
+      throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
+  }
+
+  // offset one past the `rhs` balancing the `lhs` found at `start`; the end of
+  // str when the group is never closed (comments/string literals are not skipped)
+  auto skip_group = [&str](size_t start, char lhs, char rhs) -> size_t {
+    int depth = 0;
+    for(size_t i = start; i < str.size(); ++i) {
+      depth += str[i] == lhs;
+      depth -= str[i] == rhs;
+      if(depth == 0)
+        return i + 1;
+    }
+    return str.size();
+  };
+
+  // extract functions
+  std::string ret;
+  for(const auto& k: kernels) {
+    if(std::find(names.begin(), names.end(), k.first) == names.end())
+      continue;
+    // skip over declaration; find() yields npos for a header without `(`
+    size_t args = str.find('(', k.second);
+    if(args == std::string::npos)
+      throw std::runtime_error("Unable to find argument list of kernel `" + k.first + "`");
+    // skip over definition; a trailing prototype has no `{` to find
+    size_t body = str.find('{', skip_group(args, '(', ')'));
+    if(body == std::string::npos)
+      throw std::runtime_error("Unable to find body of kernel `" + k.first + "`");
+    ret.append(str, k.second, skip_group(body, '{', '}') - k.second);
+    ret += '\n';
+  }
+
+  return ret;
+}
+
+
+
 void init_superblocking(pybind11::module &m);
 void init_launch(pybind11::module &m);
 
@@ -146,8 +205,8 @@ PYBIND11_MODULE(libtriton, m) {
       .def_readwrite("defines" , &rt::options_space_t::defines);
 
   // hooks into triton constructs since frameworks may not use pybind11
+  m.def("extract_kernels", &extract_kernels);
   m.def("get_fn_signature", &get_fn_signature);
-  // m.def("get_fn_asm", &get_fn_asm);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 012642245..b8c8a7373 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -39,15 +39,20 @@ def synchronize(device):
     dev_id = -1 if dev_id is None else dev_id
     libtriton.synchronize(dev_id)
 
-def read(path):
+def read(path, kernel_names=[]):
     with open(path, 'r') as f:
         source = f.read()
+        source = libtriton.extract_kernels(source, kernel_names)
     return source
 
+
 
 class kernel:
 
     def __init__(self, src, device, defines = dict(), num_warps = [4]):
+        # check if src is empty
+        if src == '':
+            raise ValueError('Kernel source code is empty')
         self.src = src
         self.opt = libtriton.options_space()
         self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
diff --git a/python/triton/ops/softmax.c b/python/triton/ops/softmax.c
index e363901c6..3070ed0ba 100644
--- a/python/triton/ops/softmax.c
+++ b/python/triton/ops/softmax.c
@@ -5,4 +5,4 @@ __global__ void forward(TYPE* X, TYPE* Y) {
     float shifted[BLOCK] = exp(x - x[max]);
     float sum = shifted[+];
     *(Y + off) = shifted / sum;
-}
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/python/triton/ops/softmax.py b/python/triton/ops/softmax.py
index 516ea8cc8..e00254561 100644
--- a/python/triton/ops/softmax.py
+++ b/python/triton/ops/softmax.py
@@ -2,23 +2,22 @@ import torch
 import triton
 import os
 
-kernels = dict()
-def get_kernel(block, dtype, device):
+fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
+fwd_kernels = dict()
+def get_fwd_kernel(block, dtype, device):
     key = (block, dtype, device)
-    if key not in kernels:
-        src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
+    if key not in fwd_kernels:
         defines = {'BLOCK': block, 'TYPE': dtype}
-        kernels[key] = triton.kernel(src, device = device, defines = defines)
-    return kernels[key]
+        fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
+    return fwd_kernels[key]
 
 
 class _softmax(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, x):
         y = torch.empty_like(x)
         M, N = x.shape
-        kernel = get_kernel(N, x.dtype, x.device)
+        kernel = get_fwd_kernel(N, x.dtype, x.device)
         kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
         return y
 