[PYTHON] Added utility to read single Triton kernel from provided file

in triton.read
This commit is contained in:
Philippe Tillet
2021-01-30 18:09:49 -05:00
parent 9f9d7b8840
commit f81da73b6a
5 changed files with 75 additions and 12 deletions

View File

@@ -227,7 +227,7 @@ public:
FuncDef* CurFunc() { return curFunc_; }
const TokenSequence& ts() const { return ts_; }
private:
protected:
static bool IsBuiltin(FuncType* type);
static bool IsBuiltin(const std::string& name);
static Identifier* GetBuiltin(const Token* tok);

View File

@@ -3,6 +3,7 @@
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include <regex>
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/runtime/arg.h"
@@ -15,6 +16,7 @@
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;
typedef std::pair<int, int> map_key_t;
@@ -114,6 +116,63 @@ pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args
}
std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
if(names.empty())
return str;
// search for all regex matches of kernel_regex in str
std::smatch matches;
std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
std::sregex_iterator it(str.begin(), str.end(), regex);
std::sregex_iterator end;
std::vector<std::tuple<std::string, int, int>> kernels;
for (; it != end; ++it) {
int pos = it->position();
int len = it->length();
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
for(const std::string& name: names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int>& t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// extract functions
std::string ret;
for(const auto& k: kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if(std::find(names.begin(), names.end(), name) != names.end()){
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while(!(def[pos++] == ')' && count == 0) && pos < def.size()){
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while(!(def[pos++] == '}' && count == 0) && pos < def.size()){
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
}
return ret;
}
void init_superblocking(pybind11::module &m);
void init_launch(pybind11::module &m);
@@ -146,8 +205,8 @@ PYBIND11_MODULE(libtriton, m) {
.def_readwrite("defines" , &rt::options_space_t::defines);
// hooks into triton constructs since frameworks may not use pybind11
m.def("extract_kernels", &extract_kernels);
m.def("get_fn_signature", &get_fn_signature);
// m.def("get_fn_asm", &get_fn_asm);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);

View File

@@ -39,15 +39,20 @@ def synchronize(device):
dev_id = -1 if dev_id is None else dev_id
libtriton.synchronize(dev_id)
def read(path):
def read(path, kernel_names=[]):
with open(path, 'r') as f:
source = f.read()
source = libtriton.extract_kernels(source, kernel_names)
return source
class kernel:
def __init__(self, src, device, defines = dict(), num_warps = [4]):
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
self.opt = libtriton.options_space()
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]

View File

@@ -5,4 +5,4 @@ __global__ void forward(TYPE* X, TYPE* Y) {
float shifted[BLOCK] = exp(x - x[max]);
float sum = shifted[+];
*(Y + off) = shifted / sum;
}
}

View File

@@ -2,23 +2,22 @@ import torch
import triton
import os
kernels = dict()
def get_kernel(block, dtype, device):
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()
def get_fwd_kernel(block, dtype, device):
key = (block, dtype, device)
if key not in kernels:
src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
if key not in fwd_kernels:
defines = {'BLOCK': block, 'TYPE': dtype}
kernels[key] = triton.kernel(src, device = device, defines = defines)
return kernels[key]
fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
return fwd_kernels[key]
class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = torch.empty_like(x)
M, N = x.shape
kernel = get_kernel(N, x.dtype, x.device)
kernel = get_fwd_kernel(N, x.dtype, x.device)
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
return y