# triton/python/triton/kernel.py
# imports for caching
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
import torch.utils.cpp_extension
# imports for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
# misc
import time
import platform

@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr
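
# quiet() is used by _build() below to suppress setuptools/distutils output while the
# extension module is compiled just-in-time.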

def _build(src, path, name):
    ccdir = os.path.join(libtriton.__file__, os.path.pardir)
    ccdir = os.path.realpath(ccdir)
    # include / libraries
    include_dirs = [os.path.join(ccdir, 'include')]
    library_dirs = [ccdir]
    libraries = ['triton']
    # create extension module
    abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
    extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(abi))}']
    extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)]
    extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H']
    ext = torch.utils.cpp_extension.CUDAExtension(
        name = name,
        language = 'c++',
        sources = [src],
        include_dirs = include_dirs,
        library_dirs = library_dirs,
        libraries = libraries,
        extra_compile_args = extra_compile_args,
        depends = [os.path.realpath(libtriton.__file__)]
    )
    # build extension module
    args = ['build_ext']
    tmp = tempfile.mkdtemp()
    args.append('--build-temp=' + tmp)
    args.append('--build-lib=' + path)
    args.append('-q')
    args = dict(
        name = name,
        ext_modules = [ext],
        script_args = args,
    )
    with quiet():
        setuptools.setup(**args)
    shutil.rmtree(tmp)
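
# Note: _build() is roughly equivalent to running
#   python setup.py build_ext --build-temp=<tmp> --build-lib=<path> -q
# for a CUDAExtension that links against libtriton; the temporary build directory is
# deleted once the build finishes.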

def _cvt_to_def_str(obj):
    # bool
    if isinstance(obj, bool):
        return str(int(obj))
    # torch type
    if fw.has_torch():
        if isinstance(obj, fw.torch.dtype):
            return {fw.torch.int8: 'char',
                    fw.torch.int16: 'short',
                    fw.torch.int32: 'int',
                    fw.torch.int64: 'long',
                    fw.torch.float16: 'half',
                    fw.torch.float32: 'float',
                    fw.torch.float64: 'double'}[obj]
    else:
        assert False
    # default
    return str(obj)
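
# For example, _cvt_to_def_str(True) -> '1' and _cvt_to_def_str(torch.float16) -> 'half';
# plain ints and strings fall through to str(obj).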

def _encode(arg_types):
    codes = {
        libtriton.arg_type.int1: 'i1',
        libtriton.arg_type.int8: 'i8',
        libtriton.arg_type.int32: 'i32',
        libtriton.arg_type.int64: 'i64',
        libtriton.arg_type.half: 'f16',
        libtriton.arg_type.float: 'f32',
        libtriton.arg_type.double: 'f64',
        libtriton.arg_type.buffer: 'buf'
    }
    ret = '_'.join(map(codes.get, arg_types))
    return ret
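
# For example, a (buffer, buffer, int32) signature encodes to 'buf_buf_i32';
# _make_framework_op() below uses this string both as the registered op name and as the
# name of the cache directory.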

def _make_framework_op(arg_types):
    name = _encode(arg_types)
    # path of .cpp and .so file
    home = os.path.expanduser('~')
    root = os.path.join(home, '.triton', 'torch', name)
    try:
        os.makedirs(root)
    except FileExistsError:
        pass
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(root, f'op{suffix}')
    cpp = os.path.join(root, 'op.cpp')
    # handle cached .so file
    if os.path.exists(so) and os.stat(so).st_size > 0:
        tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
        so_mtime = os.stat(so).st_mtime
        # can use cached if libtriton is older than the .so
        if tt_mtime < so_mtime:
            fw.torch.ops.load_library(so)
            return getattr(fw.torch.ops.triton, name)
    # create torch source code
    print('[TRITON] Compiling op...')
    baton = torch.utils.file_baton.FileBaton(os.path.join(root, 'lock'))
    if baton.try_acquire():
        try:
            src, _ = libtriton.make_torch_src(name, arg_types)
            with open(cpp, 'w+') as handle:
                handle.writelines(src)
            ccdir = os.path.join(libtriton.__file__, os.path.pardir)
            ccdir = os.path.realpath(ccdir)
            _build(cpp, root, 'op')
        finally:
            baton.release()
    else:
        baton.wait()
    print('[TRITON] Done compiling...')
    fw.torch.ops.load_library(so)
    return getattr(fw.torch.ops.triton, name)
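
# Note: compiled ops are cached under ~/.triton/torch/<encoded signature>/op<EXT_SUFFIX>
# and are only rebuilt when libtriton is newer than the cached .so; concurrent builds are
# serialized with a FileBaton lock file.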

class kernel:

    def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
        self.src = src
        # create constants
        self.cst = dict()
        # create triton op
        macros = []
        for k, v in defines.items():
            cvt = lambda x: _cvt_to_def_str(x)
            if isinstance(v, list):
                values = list(map(cvt, v))
            else:
                values = [cvt(v)]
            macros.append((k, values))
        opt = libtriton.options_space()
        opt.defines = macros
        opt.num_warps = num_warps
        self.op_id = libtriton.make_op_id()
        self.opt = opt
        self.registered = set()
        # create pytorch hook
        arg_types = libtriton.get_fn_signature(self.src, opt)
        self.fw_op = _make_framework_op(arg_types)

    def set_constant(self, name, value):
        libtriton.register_cst(self.op_id, name, value)

    def __call__(self, *args, **kwargs):
        for x in args:
            if isinstance(x, fw.torch.Tensor):
                device = x.device.index
                break
        # lazily register function for device
        if device not in self.registered:
            self.registered.add(device)
            libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
        # launch options
        bench = kwargs['bench'] if 'bench' in kwargs else 0
        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
        # launch grid
        if 'grid' not in kwargs:
            raise RuntimeError('Must provide grid for kernel launch')
        grid = kwargs['grid']
        libtriton.register_grid((self.op_id, device), grid)
        # launch
        self.fw_op(self.op_id, device, bench, bench_id, *args)
        if bench > 0:
            return libtriton.retrieve_scalar(bench_id)
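
# Hedged usage sketch (illustrative, not part of the original module): a vector-add launch
# written against the class above. Only kernel(...), defines, num_warps, and the grid= /
# bench= keyword arguments come from this file; the Triton-C source string and the
# one-argument grid callback follow the legacy tutorials of this era and are assumptions
# that may not match this exact revision.
if __name__ == '__main__':
    import torch
    TILE = 1024
    # hypothetical legacy Triton-C kernel; syntax not verified against this revision
    src = """
    __global__ void add(float* z, float* x, float* y, int N) {
        int pid = get_program_id(0);
        int off[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + off;
        float* px[TILE] = x + off;
        float* py[TILE] = y + off;
        bool check[TILE] = off < N;
        *?(check)pz = *?(check)px + *?(check)py;
    }
    """
    add = kernel(src, defines={'TILE': TILE}, num_warps=[4])
    N = 98432
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    z = torch.empty(N, device='cuda')
    # assumed legacy grid signature: a callable of the compilation options returning launch dims
    add(z, x, y, N, grid=lambda opt: [(N + TILE - 1) // TILE])
    print(torch.allclose(z, x + y))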