[PYTHON] Added utility to read single Triton kernel from provided file
in triton.read
This commit is contained in:
@@ -227,7 +227,7 @@ public:
|
|||||||
FuncDef* CurFunc() { return curFunc_; }
|
FuncDef* CurFunc() { return curFunc_; }
|
||||||
const TokenSequence& ts() const { return ts_; }
|
const TokenSequence& ts() const { return ts_; }
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
static bool IsBuiltin(FuncType* type);
|
static bool IsBuiltin(FuncType* type);
|
||||||
static bool IsBuiltin(const std::string& name);
|
static bool IsBuiltin(const std::string& name);
|
||||||
static Identifier* GetBuiltin(const Token* tok);
|
static Identifier* GetBuiltin(const Token* tok);
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <regex>
|
||||||
#include "triton/driver/stream.h"
|
#include "triton/driver/stream.h"
|
||||||
#include "triton/runtime/function.h"
|
#include "triton/runtime/function.h"
|
||||||
#include "triton/runtime/arg.h"
|
#include "triton/runtime/arg.h"
|
||||||
@@ -15,6 +16,7 @@
|
|||||||
using namespace triton;
|
using namespace triton;
|
||||||
namespace rt = triton::runtime;
|
namespace rt = triton::runtime;
|
||||||
namespace drv = triton::driver;
|
namespace drv = triton::driver;
|
||||||
|
namespace lng = triton::lang;
|
||||||
|
|
||||||
typedef std::pair<int, int> map_key_t;
|
typedef std::pair<int, int> map_key_t;
|
||||||
|
|
||||||
@@ -114,6 +116,63 @@ pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Returns the index of the delimiter that closes the one located at `open_pos`,
// taking nesting into account, or std::string::npos when `text` ends before the
// nesting depth drops back to zero. `text[open_pos]` is expected to be `open`.
static std::string::size_type find_matching_close(const std::string& text,
                                                   std::string::size_type open_pos,
                                                   char open, char close) {
  int depth = 0;
  for(std::string::size_type i = open_pos; i < text.size(); ++i) {
    if(text[i] == open)
      ++depth;
    else if(text[i] == close && --depth == 0)
      return i;
  }
  return std::string::npos;
}

// Extracts the definitions of the requested `__global__ void` kernels from `str`.
//
// str   : source code to scan.
// names : kernels to keep. When empty, `str` is returned unchanged.
//
// Returns the concatenation, in source order, of every requested kernel: the text
// from the start of its `__global__ void <name>` match up to and including the
// brace that closes its body, each followed by a newline. A kernel requested more
// than once in `names` is still emitted only once per definition found.
//
// Throws std::runtime_error when a requested name has no matching kernel in `str`,
// or when the parameter list / body of a requested kernel is missing or unbalanced.
//
// NOTE(review): delimiters are counted textually, so parentheses or braces inside
// string literals or comments are counted as well. The regex also caps kernel
// names at 31 characters; longer names are recorded truncated.
std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
  if(names.empty())
    return str;

  const std::string::size_type npos = std::string::npos;

  // locate every kernel declaration: remember its name and where its match starts
  struct kernel_decl {
    std::string name;
    std::string::size_type pos;
  };
  std::vector<kernel_decl> kernels;
  const std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
  for(std::sregex_iterator it(str.begin(), str.end(), regex), end; it != end; ++it)
    kernels.push_back({it->str(1), static_cast<std::string::size_type>(it->position())});

  // every requested kernel must exist before anything is extracted
  for(const std::string& name: names) {
    const bool found = std::any_of(kernels.begin(), kernels.end(),
                                   [&name](const kernel_decl& k) { return k.name == name; });
    if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
  }

  // extract functions
  std::string ret;
  for(const kernel_decl& k: kernels) {
    if(std::find(names.begin(), names.end(), k.name) == names.end())
      continue;
    // skip over declaration: matching ')' of the parameter list
    const std::string::size_type lparen = str.find('(', k.pos);
    const std::string::size_type rparen =
        lparen == npos ? npos : find_matching_close(str, lparen, '(', ')');
    // skip over definition: matching '}' of the first body opened after the parameters
    const std::string::size_type lbrace =
        rparen == npos ? npos : str.find('{', rparen + 1);
    const std::string::size_type rbrace =
        lbrace == npos ? npos : find_matching_close(str, lbrace, '{', '}');
    // npos propagates through the chain above, so one check covers every failure
    if(rbrace == npos)
      throw std::runtime_error("Unable to parse definition of kernel `" + k.name +
                               "`: missing or unbalanced parentheses/braces");
    // append straight from `str`; no need to copy the remainder of the file first
    ret.append(str, k.pos, rbrace + 1 - k.pos);
    ret += '\n';
  }

  return ret;
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void init_superblocking(pybind11::module &m);
|
void init_superblocking(pybind11::module &m);
|
||||||
void init_launch(pybind11::module &m);
|
void init_launch(pybind11::module &m);
|
||||||
|
|
||||||
@@ -146,8 +205,8 @@ PYBIND11_MODULE(libtriton, m) {
|
|||||||
.def_readwrite("defines" , &rt::options_space_t::defines);
|
.def_readwrite("defines" , &rt::options_space_t::defines);
|
||||||
|
|
||||||
// hooks into triton constructs since frameworks may not use pybind11
|
// hooks into triton constructs since frameworks may not use pybind11
|
||||||
|
m.def("extract_kernels", &extract_kernels);
|
||||||
m.def("get_fn_signature", &get_fn_signature);
|
m.def("get_fn_signature", &get_fn_signature);
|
||||||
// m.def("get_fn_asm", &get_fn_asm);
|
|
||||||
m.def("register_grid", ®ister_grid);
|
m.def("register_grid", ®ister_grid);
|
||||||
m.def("delete_grid", &delete_grid);
|
m.def("delete_grid", &delete_grid);
|
||||||
m.def("register_fn", ®ister_fn);
|
m.def("register_fn", ®ister_fn);
|
||||||
|
@@ -39,15 +39,20 @@ def synchronize(device):
|
|||||||
dev_id = -1 if dev_id is None else dev_id
|
dev_id = -1 if dev_id is None else dev_id
|
||||||
libtriton.synchronize(dev_id)
|
libtriton.synchronize(dev_id)
|
||||||
|
|
||||||
def read(path):
|
def read(path, kernel_names=[]):
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
|
source = libtriton.extract_kernels(source, kernel_names)
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class kernel:
|
class kernel:
|
||||||
|
|
||||||
def __init__(self, src, device, defines = dict(), num_warps = [4]):
|
def __init__(self, src, device, defines = dict(), num_warps = [4]):
|
||||||
|
# check if src is empty
|
||||||
|
if src == '':
|
||||||
|
raise ValueError('Kernel source code is empty')
|
||||||
self.src = src
|
self.src = src
|
||||||
self.opt = libtriton.options_space()
|
self.opt = libtriton.options_space()
|
||||||
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
|
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
|
||||||
|
@@ -5,4 +5,4 @@ __global__ void forward(TYPE* X, TYPE* Y) {
|
|||||||
float shifted[BLOCK] = exp(x - x[max]);
|
float shifted[BLOCK] = exp(x - x[max]);
|
||||||
float sum = shifted[+];
|
float sum = shifted[+];
|
||||||
*(Y + off) = shifted / sum;
|
*(Y + off) = shifted / sum;
|
||||||
}
|
}
|
@@ -2,23 +2,22 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import os
|
import os
|
||||||
|
|
||||||
kernels = dict()
|
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
|
||||||
def get_kernel(block, dtype, device):
|
fwd_kernels = dict()
|
||||||
|
def get_fwd_kernel(block, dtype, device):
|
||||||
key = (block, dtype, device)
|
key = (block, dtype, device)
|
||||||
if key not in kernels:
|
if key not in fwd_kernels:
|
||||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
|
|
||||||
defines = {'BLOCK': block, 'TYPE': dtype}
|
defines = {'BLOCK': block, 'TYPE': dtype}
|
||||||
kernels[key] = triton.kernel(src, device = device, defines = defines)
|
fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
|
||||||
return kernels[key]
|
return fwd_kernels[key]
|
||||||
|
|
||||||
|
|
||||||
class _softmax(torch.autograd.Function):
|
class _softmax(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x):
|
def forward(ctx, x):
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
M, N = x.shape
|
M, N = x.shape
|
||||||
kernel = get_kernel(N, x.dtype, x.device)
|
kernel = get_fwd_kernel(N, x.dtype, x.device)
|
||||||
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
|
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user