[GENERAL] Improved caching mechanism:
* Now computing hash in libtriton * Now only compiling a single pytorch hook per function signature
This commit is contained in:
committed by
Philippe Tillet
parent
30f77e9ec5
commit
dfb844bf41
@@ -17,14 +17,14 @@ MNK = [
|
||||
(2048, 2048, 2048),
|
||||
#(8192, 8192, 8192),
|
||||
|
||||
# (64, 64, 64000),
|
||||
# (64, 64, 128000),
|
||||
# (256, 256, 64000),
|
||||
# (256, 256, 128000),
|
||||
(64, 64, 64000),
|
||||
(64, 64, 128000),
|
||||
(256, 256, 64000),
|
||||
(256, 256, 128000),
|
||||
|
||||
# (1536, 16, 1536),
|
||||
# (1536, 32, 1536),
|
||||
# (1536, 64, 1536),
|
||||
(1536, 16, 1536),
|
||||
(1536, 32, 1536),
|
||||
(1536, 64, 1536),
|
||||
# (1536, 128, 1536),
|
||||
# (4096, 16, 4096),
|
||||
# (4096, 32, 4096),
|
||||
@@ -33,9 +33,9 @@ MNK = [
|
||||
|
||||
# (127008, 768, 576)
|
||||
]
|
||||
#for M, N, K in MNK:
|
||||
# matmul = lambda a, b: torch.matmul(a, b)
|
||||
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a, b)
|
||||
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
||||
#for M, N, K in MNK:
|
||||
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
||||
@@ -175,15 +175,15 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||
# triton output
|
||||
print(a.size(), b.size())
|
||||
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
|
||||
tc = torch.empty(c_shape, device=a.device)
|
||||
triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True)
|
||||
# reference output
|
||||
if torch_fn:
|
||||
rc = torch_fn(a, b, **arrays)
|
||||
else:
|
||||
rc = torch.einsum(expr, a, b)
|
||||
# performance relative to equivalent matrix multiplication
|
||||
ctx = triton.ctx_registry[tc]
|
||||
ctx = triton.ops._einsum.registry[tc]
|
||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||
cmp_eqbmm = False
|
||||
if cmp_eqbmm:
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
@@ -40,9 +41,10 @@ void delete_grid(size_t id) {
|
||||
/* Function map */
|
||||
|
||||
void register_fn(size_t id,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
id_fn_map[id].reset(new rt::function(src, opt));
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt,
|
||||
const std::string &cache_ref) {
|
||||
id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
|
||||
}
|
||||
|
||||
void delete_fn(size_t id) {
|
||||
@@ -64,6 +66,7 @@ size_t make_op_id() {
|
||||
return id_fn_map.size();
|
||||
}
|
||||
|
||||
|
||||
/* TF scalar wrapper */
|
||||
size_t make_scalar_id() {
|
||||
size_t ret = i64scalar_map.size();
|
||||
@@ -423,6 +426,37 @@ inline std::string to_torch_ty(ir::type *ty) {
|
||||
throw std::runtime_error("unknown type");
|
||||
}
|
||||
|
||||
inline std::string to_torch_ty(rt::arg_type ty){
|
||||
switch(ty){
|
||||
case rt::INT1_T: return "int64_t";
|
||||
case rt::INT8_T: return "int64_t";
|
||||
case rt::INT16_T: return "int64_t";
|
||||
case rt::INT32_T: return "int64_t";
|
||||
case rt::INT64_T: return "int64_t";
|
||||
case rt::HALF_T: return "double";
|
||||
case rt::FLOAT_T: return "double";
|
||||
case rt::DOUBLE_T: return "double";
|
||||
case rt::BUFFER_T: return "torch::Tensor";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string to_c_ty(rt::arg_type ty){
|
||||
switch(ty){
|
||||
case rt::INT1_T: return "bool";
|
||||
case rt::INT8_T: return "int8_t";
|
||||
case rt::INT16_T: return "int16_t";
|
||||
case rt::INT32_T: return "int32_t";
|
||||
case rt::INT64_T: return "int64_t";
|
||||
case rt::HALF_T: return "half";
|
||||
case rt::FLOAT_T: return "float";
|
||||
case rt::DOUBLE_T: return "double";
|
||||
case rt::BUFFER_T: return "drv::cu_buffer";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline std::string to_c_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty(1))
|
||||
return "bool";
|
||||
@@ -448,33 +482,30 @@ inline std::string to_c_ty(ir::type *ty) {
|
||||
|
||||
|
||||
void gen_torch_signature(std::ostringstream& oss,
|
||||
ir::function* fn,
|
||||
const std::string& name) {
|
||||
const auto& args = fn->args();
|
||||
const std::string& name,
|
||||
const std::vector<rt::arg_type>& args) {
|
||||
std::string ret_ty = "void";
|
||||
oss << ret_ty << " " << name << "(";
|
||||
oss << "int64_t id, ";
|
||||
oss << "int64_t bench, ";
|
||||
oss << "int64_t bench_id, ";
|
||||
for(size_t i = 0; i < args.size(); i++) {
|
||||
ir::argument* arg = args[i];
|
||||
if(i > 0)
|
||||
oss << ", ";
|
||||
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
|
||||
oss << to_torch_ty(args[i]) << " " << "th_arg_" << i;
|
||||
}
|
||||
oss << ")";
|
||||
}
|
||||
|
||||
void gen_torch_init_driver(std::ostringstream &oss,
|
||||
const std::vector<ir::argument*>&args) {
|
||||
ir::argument* tensor = nullptr;
|
||||
for(ir::argument* arg: args)
|
||||
if(arg->get_type()->is_pointer_ty()){
|
||||
tensor = arg;
|
||||
const std::vector<rt::arg_type>&args) {
|
||||
// Find index of first buffer
|
||||
size_t i;
|
||||
for(i = 0; i < args.size(); i++)
|
||||
if(args[i] == rt::BUFFER_T)
|
||||
break;
|
||||
}
|
||||
oss << " // Wrap CUDA handles" << std::endl;
|
||||
oss << " c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl;
|
||||
oss << " c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl;
|
||||
oss << " // Get stream" << std::endl;
|
||||
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
|
||||
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
|
||||
@@ -482,28 +513,28 @@ void gen_torch_init_driver(std::ostringstream &oss,
|
||||
}
|
||||
|
||||
void gen_torch_make_handles(std::ostream &os,
|
||||
const std::vector<ir::argument*>& args) {
|
||||
const std::vector<rt::arg_type>& args) {
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
ir::argument *arg = args[i];
|
||||
const std::string& name = arg->get_name();
|
||||
ir::type* ty = arg->get_type();
|
||||
if(!ty->is_pointer_ty())
|
||||
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
|
||||
rt::arg_type arg = args[i];
|
||||
const std::string th_name = "th_arg_" + std::to_string(i);
|
||||
const std::string name = "arg_" + std::to_string(i);
|
||||
if(arg != rt::BUFFER_T)
|
||||
os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl;
|
||||
else{
|
||||
os << " CHECK_INPUT(" << name << ");" << std::endl;
|
||||
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
|
||||
" (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
|
||||
os << " CHECK_INPUT(" << th_name << ");" << std::endl;
|
||||
os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), "
|
||||
" (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
void gen_torch_make_launch_function(std::ostream &os,
|
||||
const std::vector<rt::arg_type>& args) {
|
||||
os << " std::function<void()> run = [&](){\n ";
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = "arg_" + arg->get_name();
|
||||
if(arg->get_type()->is_pointer_ty())
|
||||
std::string name = "arg_" + std::to_string(i);
|
||||
if(args[i] == rt::BUFFER_T)
|
||||
name = "&" + name;
|
||||
if(i > 0)
|
||||
os << ", ";
|
||||
@@ -531,15 +562,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_torch_src(const std::string& src,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
std::string> make_torch_src(const std::string& name, std::vector<rt::arg_type> args) {
|
||||
// generate framework code
|
||||
std::ostringstream oss;
|
||||
oss << R"(
|
||||
@@ -563,11 +586,11 @@ extern std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
)";
|
||||
|
||||
gen_torch_signature(oss, fn, name);
|
||||
gen_torch_signature(oss, name, args);
|
||||
oss << " {" << std::endl;
|
||||
gen_torch_init_driver(oss, fn->args());
|
||||
gen_torch_make_handles(oss, fn->args());
|
||||
gen_torch_make_launch_function(oss, fn->args());
|
||||
gen_torch_init_driver(oss, args);
|
||||
gen_torch_make_handles(oss, args);
|
||||
gen_torch_make_launch_function(oss, args);
|
||||
//gen_torch_ret(oss);
|
||||
oss << "}" << std::endl;
|
||||
|
||||
@@ -578,6 +601,22 @@ extern std::map<size_t, int64_t> i64scalar_map;
|
||||
return {oss.str(), name};
|
||||
}
|
||||
|
||||
/* Function signature */
|
||||
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
// extract signature
|
||||
std::vector<rt::arg_type> ret;
|
||||
ir::function_type* ty = fn->get_fn_type();
|
||||
for(size_t i = 0; i < ty->get_num_params(); i++)
|
||||
ret.push_back(rt::convert(ty->get_param_ty(i)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
typedef triton::runtime::function::options_t options_t;
|
||||
typedef triton::runtime::function::options_space_t options_space_t;
|
||||
@@ -593,6 +632,17 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
"Creates C++ source code for a custom PyTorch op ");
|
||||
|
||||
// bindings for triton classes
|
||||
pybind11::enum_<rt::arg_type>(m, "arg_type")
|
||||
.value("int1", rt::INT1_T)
|
||||
.value("int8", rt::INT8_T)
|
||||
.value("int16", rt::INT16_T)
|
||||
.value("int32", rt::INT32_T)
|
||||
.value("int64", rt::INT64_T)
|
||||
.value("half", rt::HALF_T)
|
||||
.value("float", rt::FLOAT_T)
|
||||
.value("double", rt::DOUBLE_T)
|
||||
.value("buffer", rt::BUFFER_T);
|
||||
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
.def(pybind11::init<>())
|
||||
.def("d", &options_t::D<int>)
|
||||
@@ -604,6 +654,7 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
.def_readwrite("num_warps", &options_space_t::num_warps);
|
||||
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("get_fn_signature", &get_fn_signature);
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("delete_grid", &delete_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
|
@@ -1,5 +1,4 @@
|
||||
from .kernel import *
|
||||
from .function import *
|
||||
from .utils import *
|
||||
import triton.ops
|
||||
|
||||
|
@@ -1,127 +0,0 @@
|
||||
import triton.frameworks as fw
|
||||
import triton.utils as utils
|
||||
|
||||
class OpContext(object):
|
||||
|
||||
def __init__(self):
|
||||
self.to_save = []
|
||||
|
||||
def save_for_backward(self, *tensors):
|
||||
self.to_save = [x.to_tensor() if isinstance(x, utils.tf_empty_proxy) else x
|
||||
for x in tensors]
|
||||
|
||||
@property
|
||||
def saved_tensors(self):
|
||||
return self.to_save
|
||||
|
||||
class function_meta(type):
|
||||
|
||||
def __init__(cls, name, bases, attrs):
|
||||
cls.registered = False
|
||||
return super(function_meta, cls).__init__(name, bases, attrs)
|
||||
|
||||
ctx_registry = utils.id_dict()
|
||||
|
||||
class function(metaclass = function_meta):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def apply_torch(cls, *args, **kwargs):
|
||||
class TorchFunction(fw.torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, *targs):
|
||||
y = cls.forward(ctx, *targs, **cls.torch_kwargs)
|
||||
ctx_registry[y] = ctx
|
||||
return y
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return cls.backward(ctx, grad_output)
|
||||
cls.torch_kwargs = kwargs
|
||||
return TorchFunction.apply(*args)
|
||||
torch_kwargs = 0
|
||||
|
||||
@classmethod
|
||||
def extract_tf_tensors(cls, lst, err):
|
||||
ret = []
|
||||
for x in lst:
|
||||
if x is None:
|
||||
ret += [None]
|
||||
elif isinstance(x, fw.tensorflow.Tensor):
|
||||
ret += [x]
|
||||
elif isinstance(x, utils.tf_empty_proxy):
|
||||
if x.tensor is None:
|
||||
raise ValueError('Empty tensor never filled during ' + err)
|
||||
else:
|
||||
ret += [x.tensor]
|
||||
else:
|
||||
raise ValueError('Unsupported return type', type(x))
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def map_in_to_args(cls, op, args):
|
||||
ret = dict()
|
||||
for i, ix in enumerate(op.inputs):
|
||||
for j, jx in enumerate(args):
|
||||
if ix is jx:
|
||||
ret[j] = i
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def map_res_to_out(cls, op, result):
|
||||
ret = []
|
||||
for i, ix in enumerate(result):
|
||||
for j, jx in enumerate(op.outputs):
|
||||
if ix is jx:
|
||||
ret.append(j)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def apply_tensorflow(cls, *args, **kwargs):
|
||||
ctx = OpContext()
|
||||
|
||||
# run forward pass
|
||||
result = cls.forward(ctx, *args, **kwargs)
|
||||
result = result if isinstance(result, tuple) else (result, )
|
||||
result = function.extract_tf_tensors(result, 'forward')
|
||||
|
||||
# Register backward pass
|
||||
op = result[0].op
|
||||
ctx_registry[op] = ctx
|
||||
if not cls.registered:
|
||||
remap_in = cls.map_in_to_args(op, args)
|
||||
remap_out = cls.map_res_to_out(op, result)
|
||||
@fw.tensorflow.RegisterGradient(op.op_def.name)
|
||||
def gradient(op, *dy):
|
||||
# Remap gradient inputs in the right order
|
||||
dy = [dy[i] for i in remap_out]
|
||||
dy = dy if len(dy) > 1 else dy[0]
|
||||
# Execute gradient function
|
||||
grad = cls.backward(ctx_registry[op], dy)
|
||||
grad = function.extract_tf_tensors(grad, 'backward')
|
||||
# Remap gradient in the right order
|
||||
ret = [None] * len(op.inputs)
|
||||
for i in range(len(grad)):
|
||||
if i in remap_in:
|
||||
ret[remap_in[i]] = grad[i]
|
||||
# Return
|
||||
return ret
|
||||
cls.registered = True
|
||||
|
||||
# Return tensor
|
||||
return result[0] if len(result)==1 else result
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *args, **kwargs):
|
||||
if fw.has_tensorflow():
|
||||
return cls.apply_tensorflow(*args, **kwargs)
|
||||
elif fw.has_torch():
|
||||
return cls.apply_torch(*args, **kwargs)
|
||||
else:
|
||||
assert False
|
@@ -17,43 +17,6 @@ import triton.frameworks as fw
|
||||
import triton.utils
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
def _make_framework_src(src, grid):
|
||||
if fw.has_torch:
|
||||
return libtriton.make_torch_src(src, grid)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _make_cache_path(src):
|
||||
md5 = hashlib.sha1(src.encode())
|
||||
hexhash = md5.hexdigest()
|
||||
home = os.path.expanduser('~')
|
||||
cacheroot = os.path.join(home, '.triton', 'cache')
|
||||
cachepath = os.path.join(cacheroot, str(hexhash))
|
||||
if not os.path.exists(cachepath):
|
||||
os.makedirs(cachepath)
|
||||
return cachepath
|
||||
|
||||
def _write_bindings(src, root):
|
||||
if fw.has_torch():
|
||||
name = 'torch'
|
||||
else:
|
||||
assert False
|
||||
cpp = os.path.join(root, '{name}.cpp'.format(name=name))
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
|
||||
recompile = False
|
||||
# recompile if .so does not exist
|
||||
if not os.path.exists(cpp) or not os.path.exists(so):
|
||||
recompile = True
|
||||
# recompile if cpp was modified after .so
|
||||
elif max(cpp, so, key=os.path.getctime) == cpp:
|
||||
recompile = True
|
||||
# write cpp file
|
||||
if recompile:
|
||||
with open(cpp, 'w+') as handle:
|
||||
handle.writelines(src)
|
||||
# return path of cpp file
|
||||
return (cpp, so)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quiet():
|
||||
@@ -64,7 +27,7 @@ def quiet():
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_stdout, old_stderr
|
||||
|
||||
def _build(src, path):
|
||||
def _build(src, path, name):
|
||||
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
|
||||
ccdir = os.path.realpath(ccdir)
|
||||
# include directories
|
||||
@@ -88,7 +51,6 @@ def _build(src, path):
|
||||
libraries += ['torch']
|
||||
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
|
||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
||||
name = 'torch'
|
||||
else:
|
||||
assert False
|
||||
# extra arguments
|
||||
@@ -142,30 +104,47 @@ def _cvt_to_def_str(obj):
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _make_framework_op(src, options):
|
||||
src, name = _make_framework_src(src, options)
|
||||
cache_path = _make_cache_path(src)
|
||||
cpp, so = _write_bindings(src, cache_path)
|
||||
_build(cpp, cache_path)
|
||||
if fw.has_torch():
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
else:
|
||||
assert False
|
||||
def _encode(arg_types):
|
||||
codes = {
|
||||
libtriton.arg_type.int1: 'i1',
|
||||
libtriton.arg_type.int8: 'i8',
|
||||
libtriton.arg_type.int32: 'i32',
|
||||
libtriton.arg_type.int64: 'i64',
|
||||
libtriton.arg_type.half: 'f16',
|
||||
libtriton.arg_type.float: 'f32',
|
||||
libtriton.arg_type.double: 'f64',
|
||||
libtriton.arg_type.buffer: 'buf'
|
||||
}
|
||||
ret = '_'.join(map(codes.get, arg_types))
|
||||
return ret
|
||||
|
||||
def _make_grid(grid, args) :
|
||||
scalars = [x for x in args if isinstance(x, triton.utils.scalar)]
|
||||
def grid(opt):
|
||||
for x in scalars:
|
||||
x.set_assume_initialized()
|
||||
result = grid(opt)
|
||||
for x in scalars:
|
||||
x.unset_assume_initialized()
|
||||
return result
|
||||
return grid
|
||||
|
||||
|
||||
bench_registry = triton.utils.id_dict()
|
||||
def _make_framework_op(arg_types):
|
||||
name = _encode(arg_types)
|
||||
# path of .cpp and .so file
|
||||
home = os.path.expanduser('~')
|
||||
root = os.path.join(home, '.triton', 'torch', name)
|
||||
if not os.path.exists(root):
|
||||
os.makedirs(root)
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(root, f'op{suffix}')
|
||||
cpp = os.path.join(root, f'op.cpp')
|
||||
# handle cached .so file
|
||||
if os.path.exists(so):
|
||||
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
|
||||
so_mtime = os.stat(so).st_mtime
|
||||
# can use cached if libtriton is older than the .so
|
||||
if tt_mtime < so_mtime:
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
# create torch source code
|
||||
src, _ = libtriton.make_torch_src(name, arg_types)
|
||||
with open(cpp, 'w+') as handle:
|
||||
handle.writelines(src)
|
||||
# compile torch source code
|
||||
_build(cpp, root, 'op')
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
|
||||
|
||||
class kernel:
|
||||
|
||||
@@ -180,9 +159,8 @@ class kernel:
|
||||
self.cst[name] = value
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
########################
|
||||
# keyword arguments
|
||||
# JIT Options
|
||||
########################
|
||||
num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8]
|
||||
defines = kwargs['defines'] if 'defines' in kwargs else dict()
|
||||
@@ -195,7 +173,6 @@ class kernel:
|
||||
#########################
|
||||
# cache
|
||||
########################
|
||||
|
||||
# create a new framework op when defines are different
|
||||
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
|
||||
if key not in self.fw_id.keys():
|
||||
@@ -211,15 +188,16 @@ class kernel:
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = macros
|
||||
opt.num_warps = num_warps
|
||||
# create unique id for this op
|
||||
# create triton function for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
# register function
|
||||
libtriton.register_fn(op_id, self.src, opt)
|
||||
libtriton.register_fn(op_id, self.src, opt, os.path.realpath(libtriton.__file__))
|
||||
for name, value in self.cst.items():
|
||||
libtriton.register_cst(op_id, name, value)
|
||||
# create pytorch hook for this op
|
||||
arg_types = libtriton.get_fn_signature(self.src, opt)
|
||||
if self.fw_op is None:
|
||||
self.fw_op = _make_framework_op(self.src, opt)
|
||||
self.fw_op = _make_framework_op(arg_types)
|
||||
|
||||
########################
|
||||
# initialize
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import triton
|
||||
import torch
|
||||
import math
|
||||
|
||||
class _batchnorm(triton.function):
|
||||
class _batchnorm(torch.autograd.Function):
|
||||
|
||||
fwd_src = """
|
||||
void fwdbatchnorm(float *Y, float *M, float *V,
|
||||
|
@@ -73,7 +73,7 @@ class _einsum(torch.autograd.Function):
|
||||
stride_a_last, stride_b_last, stride_c_last,
|
||||
lut_mode_a, lut_mode_b,
|
||||
delta_a, delta_b,
|
||||
subscripted):
|
||||
subscripted, varnames):
|
||||
|
||||
use_lut_a = True
|
||||
use_lut_b = True
|
||||
@@ -123,6 +123,8 @@ __global__ void {name}(
|
||||
src += "\n"
|
||||
for ptr in subscripted:
|
||||
src += f", int* {ptr}"
|
||||
for name in varnames:
|
||||
src += f", int {name}"
|
||||
src += """) {
|
||||
|
||||
// re-order outer program ids
|
||||
@@ -274,6 +276,9 @@ __global__ void {name}(
|
||||
TYPE c[TM, TN, TB] = acc;
|
||||
|
||||
// re-materialize ranges
|
||||
pid_mn = get_program_id(0) / div_m;
|
||||
pid_n = pid_mn % grid_n;
|
||||
pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
|
||||
"""
|
||||
for axes, tile, off in zip([axes_m, axes_n, axes_b],
|
||||
['TM', 'TN', 'TB'],
|
||||
@@ -410,11 +415,8 @@ __global__ void {name}(
|
||||
batch = [d for d in sym_a if d in sym_b and d in sym_c]
|
||||
outer = [d for d in sym_a if d not in sym_b and d in sym_c]
|
||||
inner = [d for d in sym_a if d in sym_b and d not in sym_c]
|
||||
illegal = [d for d in sym_a if d not in sym_b and d not in sym_c]
|
||||
if illegal:
|
||||
raise ValueError(f"einsum labels {illegal} ({expr_a}) "\
|
||||
f"not present in {expr_b} or {expr_c}")
|
||||
return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner)
|
||||
variables = [d for d in sym_a if d not in sym_b and d not in sym_c]
|
||||
return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner), variables
|
||||
|
||||
|
||||
def replace_subscript(expr, arrays):
|
||||
@@ -467,7 +469,33 @@ __global__ void {name}(
|
||||
locks = None
|
||||
kernel_cache = dict()
|
||||
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c):
|
||||
@staticmethod
|
||||
def _tile(M, N, B, TMs, TNs, TBs, TZs, TK):
|
||||
smp = 15
|
||||
# occupancy estimation
|
||||
grid = lambda TM, TN, TB, TZ: \
|
||||
triton.cdiv(M, TM)* \
|
||||
triton.cdiv(N, TN)* \
|
||||
triton.cdiv(B, TB)* \
|
||||
TZ
|
||||
occupancy = lambda TM, TN, TB, TZ: \
|
||||
min(grid(TM, TN, TB, TZ), 4*smp)
|
||||
# arithmetic intensity estimation
|
||||
intensity = lambda TM, TN: \
|
||||
TM * TN * TK / (TM*TK + TK*TN)
|
||||
# occupancy/intensity for all configurations
|
||||
estimates = {(TM, TN, TB, TZ): (occupancy(TM, TN, TB, TZ), intensity(TM, TN)) \
|
||||
for TM in TMs \
|
||||
for TN in TNs \
|
||||
for TB in TBs \
|
||||
for TZ in TZs }
|
||||
# returns configuration that maximizes occupancy subject to maximizing intensity
|
||||
estimates = sorted(estimates.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True)
|
||||
return estimates[0][0]
|
||||
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c, varnames):
|
||||
# parse symbols
|
||||
expr_a, expr_bc = einsum.split(",")
|
||||
expr_b, expr_c = expr_bc.split("->")
|
||||
@@ -476,9 +504,13 @@ __global__ void {name}(
|
||||
sym_b = _einsum.parse_expr(expr_b, subscripted)
|
||||
sym_c = _einsum.parse_expr(expr_c, subscripted)
|
||||
# parse axes
|
||||
axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
|
||||
_, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
|
||||
axes_b, axes_m, axes_k, var = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
|
||||
_, axes_n, _, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
|
||||
axes = axes_b + axes_m + axes_n + axes_k
|
||||
# unresolved symbols
|
||||
unresolved = [x for x in map(str, var) if x not in varnames]
|
||||
if unresolved:
|
||||
raise ValueError(f'unresolved symbols: {unresolved}')
|
||||
# check dimensions
|
||||
dims_a = dict(zip(sym_a, shape_a))
|
||||
dims_b = dict(zip(sym_b, shape_b))
|
||||
@@ -520,7 +552,7 @@ __global__ void {name}(
|
||||
stride_a_last, stride_b_last, stride_c_last,
|
||||
lut_mode_a, lut_mode_b,
|
||||
delta_a, delta_b,
|
||||
subscripted)
|
||||
subscripted, varnames)
|
||||
self.kernel = cache[name]
|
||||
# Initialize locks
|
||||
if _einsum.instance.locks is None:
|
||||
@@ -565,19 +597,21 @@ __global__ void {name}(
|
||||
self.pos_a = 0
|
||||
self.pos_b = 1
|
||||
self.pos_c = 2
|
||||
# pre-processor macros
|
||||
TM = [16] + [x for x in [32, 64, 128] if x <= M]
|
||||
TN = [16] + [x for x in [32, 64, 128] if x <= N]
|
||||
TB = [x for x in [1, 2, 4] if x <= B]
|
||||
MAX_GZ = K // 2048
|
||||
MIN_GM = M // max(TM)
|
||||
MIN_GN = N // max(TN)
|
||||
MIN_GB = B // max(TB)
|
||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||
TM, TN, TB, TZ = 64, 64, 1, 1
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||
# user-provided variables
|
||||
self.pos_vars = len(self.args)
|
||||
self.varnames = varnames
|
||||
self.args += [None] * len(varnames)
|
||||
# tile size ranges
|
||||
MAX_GZ = triton.cdiv(K, 2048)
|
||||
TMs = [16] + [x for x in [32, 64, 128] if x <= M]
|
||||
TNs = [16] + [x for x in [32, 64, 128] if x <= N]
|
||||
TBs = [x for x in [1, 2, 4, 8] if x <= B]
|
||||
TZs = [x for x in [1, 2, 4, 8, 16, 32] if x <= MAX_GZ]
|
||||
# tile sizes
|
||||
TM, TN, TB, TZ = _einsum.instance._tile(M, N, B, TMs, TNs, TBs, TZs, TK)
|
||||
TM, TN, TB, TZ = 64, 128, 1, 1
|
||||
self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||
self.num_warps = [4]
|
||||
if mask:
|
||||
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||
# save information on the operation
|
||||
@@ -589,12 +623,15 @@ __global__ void {name}(
|
||||
self.matmul_N = N
|
||||
self.matmul_K = K
|
||||
self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
|
||||
|
||||
|
||||
def run(self, a, b, c, bench):
|
||||
def run(self, a, b, c, values, bench):
|
||||
self.args[self.pos_a] = a
|
||||
self.args[self.pos_b] = b
|
||||
self.args[self.pos_c] = c
|
||||
return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
|
||||
for i, name in enumerate(self.varnames):
|
||||
self.args[self.pos_vars + i] = values[name]
|
||||
return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros, num_warps=self.num_warps)
|
||||
|
||||
|
||||
|
||||
@@ -604,8 +641,9 @@ __global__ void {name}(
|
||||
############################
|
||||
|
||||
instance_cache = dict()
|
||||
registry = triton.utils.id_dict()
|
||||
@staticmethod
|
||||
def forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
|
||||
def forward(ctx, expr, a, b, output, mask, arrays, bench, values):
|
||||
# compile einsum instance
|
||||
cache = _einsum.instance_cache
|
||||
key = (expr, a.dtype,
|
||||
@@ -615,10 +653,10 @@ __global__ void {name}(
|
||||
cache[key] = _einsum.instance(expr, a.dtype,
|
||||
a.stride(), b.stride(), output.stride(),
|
||||
a.shape, b.shape, arrays,
|
||||
mask, output.shape)
|
||||
mask, output.shape, values.keys())
|
||||
instance = cache[key]
|
||||
# run and mark as dirty output modified in-place
|
||||
perf = instance.run(a, b, output, bench)
|
||||
perf = instance.run(a, b, output, values, bench)
|
||||
ctx.mark_dirty(output)
|
||||
# save information in context
|
||||
ctx.is_extended = instance.is_extended
|
||||
@@ -629,8 +667,9 @@ __global__ void {name}(
|
||||
ctx.matmul_M = instance.matmul_M
|
||||
ctx.matmul_N = instance.matmul_N
|
||||
ctx.matmul_K = instance.matmul_K
|
||||
ctx.perf = perf
|
||||
ctx.forward_ms = perf
|
||||
ctx.save_for_backward(a, b)
|
||||
_einsum.registry[output] = ctx
|
||||
return output
|
||||
|
||||
|
||||
@@ -662,5 +701,5 @@ __global__ void {name}(
|
||||
|
||||
def einsum(expr, a, b, output,
|
||||
mask=None, arrays=dict(),
|
||||
bench=False):
|
||||
return _einsum.apply(expr, a, b, output, mask, arrays, bench)
|
||||
bench=False, values=dict()):
|
||||
return _einsum.apply(expr, a, b, output, mask, arrays, bench, values)
|
Reference in New Issue
Block a user