[PYTHON] Added utility to read single Triton kernel from provided file
in triton.read
This commit is contained in:
@@ -227,7 +227,7 @@ public:
|
||||
FuncDef* CurFunc() { return curFunc_; }
|
||||
const TokenSequence& ts() const { return ts_; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
static bool IsBuiltin(FuncType* type);
|
||||
static bool IsBuiltin(const std::string& name);
|
||||
static Identifier* GetBuiltin(const Token* tok);
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
@@ -15,6 +16,7 @@
|
||||
using namespace triton;
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
namespace lng = triton::lang;
|
||||
|
||||
typedef std::pair<int, int> map_key_t;
|
||||
|
||||
@@ -114,6 +116,63 @@ pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args
|
||||
}
|
||||
|
||||
|
||||
// Returns the index one past the `close` character that balances the `open`
// character located at `open_pos` in `src`, taking nesting into account.
// Returns src.size() when the delimiters never balance.
static std::string::size_type skip_balanced(const std::string& src,
                                            std::string::size_type open_pos,
                                            char open, char close) {
  int depth = 1;  // the delimiter at `open_pos` is already open
  for(std::string::size_type i = open_pos + 1; i < src.size(); ++i) {
    if(src[i] == open)
      ++depth;
    else if(src[i] == close && --depth == 0)
      return i + 1;
  }
  return src.size();
}

// Extracts the definitions of the `__global__ void` kernels listed in `names`
// from the source text `str`.
//
// - If `names` is empty, `str` is returned unchanged.
// - Otherwise the result is the concatenation, in source order, of every
//   requested kernel definition (from the start of its `__global__` match,
//   including leading spaces, through the closing `}` of its body), each
//   followed by a newline.
// - If the body's braces never balance, the definition extends to the end of
//   `str`.
//
// Throws std::runtime_error when a requested name has no `__global__ void`
// declaration in `str`, or when a requested kernel has no parameter list or
// no body following its declaration.
//
// Note: parentheses and braces are counted textually; delimiters inside
// string literals or comments are not treated specially.
std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
  if(names.empty())
    return str;

  // a kernel declaration found in `str`: its name and where the match starts
  struct kernel_decl {
    std::string name;
    std::string::size_type pos;
  };

  // search for all kernel declarations in str
  const std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
  std::vector<kernel_decl> kernels;
  for(std::sregex_iterator it(str.begin(), str.end(), regex), end; it != end; ++it) {
    kernel_decl decl;
    decl.name = it->str(1);
    decl.pos = static_cast<std::string::size_type>(it->position());
    kernels.push_back(decl);
  }

  // every requested name must correspond to a declaration
  for(const std::string& name: names) {
    auto pred = [&name](const kernel_decl& k) { return k.name == name; };
    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
    if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
  }

  // extract functions
  std::string ret;
  for(const kernel_decl& k: kernels) {
    if(std::find(names.begin(), names.end(), k.name) == names.end())
      continue;
    // skip over declaration: from the first '(' to its matching ')'
    const std::string::size_type lparen = str.find('(', k.pos);
    if(lparen == std::string::npos)
      throw std::runtime_error("Malformed kernel `" + k.name + "`: missing parameter list");
    const std::string::size_type decl_end = skip_balanced(str, lparen, '(', ')');
    // skip over definition: from the next '{' to its matching '}'
    const std::string::size_type lbrace = str.find('{', decl_end);
    if(lbrace == std::string::npos)
      throw std::runtime_error("Malformed kernel `" + k.name + "`: missing function body");
    const std::string::size_type def_end = skip_balanced(str, lbrace, '{', '}');
    // append straight from `str`; no intermediate copy of the source tail
    ret.append(str, k.pos, def_end - k.pos);
    ret += '\n';
  }

  return ret;
}
|
||||
|
||||
|
||||
|
||||
void init_superblocking(pybind11::module &m);
|
||||
void init_launch(pybind11::module &m);
|
||||
|
||||
@@ -146,8 +205,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
.def_readwrite("defines" , &rt::options_space_t::defines);
|
||||
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("extract_kernels", &extract_kernels);
|
||||
m.def("get_fn_signature", &get_fn_signature);
|
||||
// m.def("get_fn_asm", &get_fn_asm);
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("delete_grid", &delete_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
|
@@ -39,15 +39,20 @@ def synchronize(device):
|
||||
dev_id = -1 if dev_id is None else dev_id
|
||||
libtriton.synchronize(dev_id)
|
||||
|
||||
def read(path):
|
||||
def read(path, kernel_names=None):
    """Read source code from `path`, optionally keeping only some kernels.

    Args:
        path: Path of the text file to read.
        kernel_names: Names of the kernels to keep. Defaults to no names,
            which is passed to `libtriton.extract_kernels` as an empty list.

    Returns:
        The string returned by `libtriton.extract_kernels` for the file's
        contents and `kernel_names`.
    """
    # Use a None sentinel rather than a shared mutable default list.
    if kernel_names is None:
        kernel_names = []
    with open(path, 'r') as f:
        source = f.read()
    # Selection of the requested kernels is delegated to the C++ binding.
    source = libtriton.extract_kernels(source, kernel_names)
    return source
|
||||
|
||||
|
||||
|
||||
class kernel:
|
||||
|
||||
def __init__(self, src, device, defines = dict(), num_warps = [4]):
|
||||
# check if src is empty
|
||||
if src == '':
|
||||
raise ValueError('Kernel source code is empty')
|
||||
self.src = src
|
||||
self.opt = libtriton.options_space()
|
||||
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
|
||||
|
@@ -5,4 +5,4 @@ __global__ void forward(TYPE* X, TYPE* Y) {
|
||||
float shifted[BLOCK] = exp(x - x[max]);
|
||||
float sum = shifted[+];
|
||||
*(Y + off) = shifted / sum;
|
||||
}
|
||||
}
|
@@ -2,23 +2,22 @@ import torch
|
||||
import triton
|
||||
import os
|
||||
|
||||
kernels = dict()
|
||||
def get_kernel(block, dtype, device):
|
||||
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
|
||||
fwd_kernels = dict()
|
||||
def get_fwd_kernel(block, dtype, device):
|
||||
key = (block, dtype, device)
|
||||
if key not in kernels:
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
|
||||
if key not in fwd_kernels:
|
||||
defines = {'BLOCK': block, 'TYPE': dtype}
|
||||
kernels[key] = triton.kernel(src, device = device, defines = defines)
|
||||
return kernels[key]
|
||||
fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
|
||||
return fwd_kernels[key]
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
y = torch.empty_like(x)
|
||||
M, N = x.shape
|
||||
kernel = get_kernel(N, x.dtype, x.device)
|
||||
kernel = get_fwd_kernel(N, x.dtype, x.device)
|
||||
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
|
||||
return y
|
||||
|
||||
|
Reference in New Issue
Block a user