Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
453 lines
17 KiB
Python
453 lines
17 KiB
Python
from __future__ import annotations, division
|
|
|
|
import ast
|
|
import functools
|
|
import hashlib
|
|
import inspect
|
|
import os
|
|
import subprocess
|
|
import textwrap
|
|
from collections import namedtuple
|
|
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
|
|
|
|
import torch
|
|
|
|
import triton
|
|
from triton.utils import MockTensor
|
|
|
|
# Prefer the private torch binding that returns the current raw stream; when
# this torch build does not export it, fall back to the public stream API.
try:
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
except ImportError:
    def get_cuda_stream(dev_idx):
        """Return ``cuda_stream`` of the current stream on device ``dev_idx``."""
        return torch.cuda.current_stream(dev_idx).cuda_stream
|
|
|
|
|
|
# Type of the function handed to `jit`; KernelInterface and JITFunction are
# generic in it so that `run` / `fn[grid]` keep that function's static type.
T = TypeVar('T')
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Dependencies Finder
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class DependenciesFinder(ast.NodeVisitor):
    """
    AST visitor that folds the hashes of the JITFunctions called by a
    function into one digest.

    ``ret`` starts as the MD5 of the function's own source and is re-hashed
    with the hash of every JITFunction callee met while visiting, so it
    changes whenever the source -- or that of a dependency -- changes.
    """

    def __init__(self, globals, src) -> None:
        super().__init__()
        self.globals = globals
        self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()

    def visit_Name(self, node):
        # Names are looked up in the supplied globals only; unknown names
        # resolve to None.
        return self.globals.get(node.id, None)

    def visit_Attribute(self, node):
        owner = self.visit(node.value)
        while isinstance(owner, ast.Attribute):
            owner = self.visit(owner.value)
        # Unresolvable bases and the `triton` package itself are ignored.
        if owner is None or owner is triton:
            return None
        return getattr(owner, node.attr)

    def visit_Call(self, node):
        callee = self.visit(node.func)
        # Nothing to hash for unresolved callees and builtins.
        if callee is None or inspect.isbuiltin(callee):
            return
        module_name = callee.__module__
        if module_name and module_name.startswith('triton.'):
            return
        assert isinstance(callee, JITFunction)
        if callee.hash is None:
            # Lazily compute (and store) the callee's own dependency hash.
            callee_tree = ast.parse(callee.src)
            nested = DependenciesFinder(callee.__globals__, callee.src)
            nested.visit(callee_tree)
            callee.hash = nested.ret
        combined = (self.ret + callee.hash).encode("utf-8")
        self.ret = hashlib.md5(combined).hexdigest()
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# JITFunction
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
@functools.lru_cache()
def version_key():
    """
    Build a string identifying the installed Triton toolchain.

    The key is made of ``triton.__version__``, the MD5 of the output of
    ``ptxas --version`` (empty string when ptxas cannot be run) and the MD5
    digests of this file, the ``triton.compiler`` module file, the
    ``libtriton`` binary and every module found in the ``language``
    directory of the triton package.  The result is memoized by
    ``functools.lru_cache``.
    """
    import pkgutil

    def file_digest(path):
        with open(path, "rb") as handle:
            return hashlib.md5(handle.read()).hexdigest()

    digests = [
        # frontend
        file_digest(__file__),
        file_digest(triton.compiler.__file__),
        # backend
        file_digest(triton._C.libtriton.__file__),
    ]
    # language
    language_dir = os.path.join(*triton.__path__, 'language')
    digests.extend(
        file_digest(module.module_finder.find_spec(module.name).origin)
        for module in pkgutil.iter_modules([language_dir])
    )
    # ptxas version
    try:
        ptxas_output = subprocess.check_output(["ptxas", "--version"])
        ptxas_digest = hashlib.md5(ptxas_output).hexdigest()
    except Exception:
        # best effort: the key simply carries no ptxas component
        ptxas_digest = ''
    # NOTE(review): join() iterates over triton.__version__ itself; if that is
    # a plain string its characters get separated by '-' -- kept as is.
    return '-'.join(triton.__version__) + '-' + ptxas_digest + '-' + '-'.join(digests)
|
|
|
|
|
|
class KernelInterface(Generic[T]):
    """Adds the ``fn[grid](*args, **kwargs)`` launch syntax on top of ``run``."""

    # Launcher callable; expected to be provided by the concrete object.
    run: T

    def __getitem__(self, grid) -> T:
        """
        Bind ``grid`` to the launcher.

        Indexing returns ``self.run`` partially applied with ``grid=grid``,
        which is what makes ``fn[grid](*args, **kwargs)`` work.
        """
        launcher = cast(Callable, self.run)
        bound = functools.partial(launcher, grid=grid)
        return cast(T, bound)
|
|
|
|
|
|
|
|
class JITFunction(KernelInterface[T]):
    """
    Wrapper created by ``jit`` around a Python function.

    It records the function's source and argument names, keeps a cache of
    compiled kernels and exposes ``run`` -- a launcher generated at
    construction time -- which looks the kernel up in the cache, compiles it
    with ``triton.compile`` on a miss and launches it.
    """

    # Optional callable consulted on a cache miss, before compiling (see
    # `_call_hook`); a truthy return value skips compilation.
    cache_hook = None
    # Modulus used for the "divisible by" specialization of pointers/integers.
    divisibility = 16

    @staticmethod
    def _key_of(arg):
        """
        Return the signature key of a launch argument.

        Objects with a ``dtype`` attribute yield that dtype, ``bool`` yields
        ``"i1"``, ``int`` yields ``"i32"``/``"u32"``/``"u64"``/``"i64"``
        depending on its value, ``float`` yields ``'fp32'`` and ``None``
        yields ``None``.

        :raises TypeError: for any other kind of argument.
        """
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            # must come before the int test: bool is a subclass of int
            return "i1"
        elif isinstance(arg, int):
            if -2**31 <= arg <= 2**31 - 1:
                return "i32"
            elif 2**31 <= arg <= 2**32 - 1:
                return "u32"
            elif 2**63 <= arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return 'fp32'
        elif arg is None:
            return None
        else:
            raise TypeError(f'Unsupported type {type(arg)} for {arg}')

    @staticmethod
    def _spec_of(arg):
        """
        Return the specialization key of a launch argument.

        Objects with ``data_ptr`` yield whether the pointer is a multiple of
        ``JITFunction.divisibility``; integers yield a pair (multiple of
        ``divisibility``, equal to 1); anything else yields ``(arg is None,)``.
        """
        if hasattr(arg, "data_ptr"):
            return (arg.data_ptr() % JITFunction.divisibility == 0)
        elif isinstance(arg, int):
            # use the class constant, like the launcher and _get_config do
            return (arg % JITFunction.divisibility == 0, arg == 1)
        return (arg is None, )

    def _get_config(self, *args):
        """
        Build the ``instance_descriptor`` for a list of launch arguments.

        :return: a namedtuple with ``divisible_by_16`` (positions whose
            pointer/integer value is a multiple of ``divisibility``, plus
            positions holding ``None``) and ``equal_to_1`` (positions holding
            the integer 1); positions listed in ``do_not_specialize`` are
            left out of both.
        """
        def is_divisible_by_16(x):
            if hasattr(x, "data_ptr"):
                return x.data_ptr() % JITFunction.divisibility == 0
            elif isinstance(x, int):
                return x % JITFunction.divisibility == 0
            if x is None:
                return True
            return False

        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))

    @staticmethod
    def _type_of(key):
        """
        Turn a signature key (see ``_key_of``) into a type string.

        torch / triton.language dtypes map to their short name prefixed with
        ``*`` (presumably "pointer to" -- confirm against the compiler),
        ``None`` maps to ``'*i8'`` and strings are returned unchanged.

        :raises KeyError: for a dtype that is not in the table below.
        """
        if isinstance(key, (torch.dtype, triton.language.dtype)):
            ty = {
                torch.bool: 'i1',
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16',
                torch.float32: 'fp32',
                torch.float64: 'fp64',
                torch.uint8: 'u8',
                torch.int8: 'i8',
                torch.int16: 'i16',
                torch.int32: 'i32',
                torch.int64: 'i64',

                triton.language.uint8: 'u8',
                triton.language.uint16: 'u16',
                triton.language.uint32: 'u32',
                triton.language.uint64: 'u64',
                triton.language.float8: 'fp8',
                triton.language.float16: 'fp16',
                triton.language.bfloat16: 'bf16',
                triton.language.float32: 'fp32',
            }[key]
            return f'*{ty}'
        if key is None:
            return '*i8'
        assert isinstance(key, str)
        return key

    def _make_signature(self, sig_key):
        """Join the type strings of all entries of ``sig_key`` with commas."""
        return ",".join(self._type_of(k) for k in sig_key)

    def _make_constants(self, constexpr_key):
        """Map each constexpr argument index to its value from ``constexpr_key``."""
        return dict(zip(self.constexprs, constexpr_key))

    def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        """
        Invoke ``JITFunction.cache_hook`` if one is installed.

        :return: ``False`` when no hook is installed, otherwise whatever the
            hook returns; the generated launcher skips compilation when this
            is truthy.
        """
        if JITFunction.cache_hook is None:
            return False
        name = self.fn.__name__
        module = self.fn.__module__
        # key[1] is the signature key built by the launcher
        arg_reprs = ', '.join([f'{arg_name}: {ty}' for arg_name, ty in zip(self.arg_names, key[1])])
        kernel_repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})"
        key = str(key)

        class LegacyCompiler:
            # Minimal stand-in passed to the hook as `fn`.
            def __init__(self, module, name):
                self.module = module
                self.name = name

        kwargs = dict(signature=signature, device=device, constants=constants,
                      num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs,
                      configs=configs)

        return JITFunction.cache_hook(key=key, repr=kernel_repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False)

    def _make_launcher(self):
        """
        Generate, ``exec`` and return the launcher function bound to ``run``.

        The generated function takes the kernel's arguments followed by
        ``grid, num_warps=4, num_stages=3, extern_libs=None, stream=None,
        warmup=False``.  It builds a cache key from the version key, the
        argument types, the constexpr values and the specialization of the
        regular arguments, launches the cached kernel on a hit, and on a miss
        compiles, launches (unless ``warmup``) and caches it.  It returns the
        compiled kernel, or ``None`` when the cache hook handled the miss.
        """
        regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
        constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
        args = ', '.join(regular_args)
        # cache key for regular argument type
        sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args])
        # cache key for constexpr argument values
        constexpr_keys = ', '.join(constexpr_args)
        # cache key for argument specialization
        specializations = []
        for i, arg in enumerate(regular_args):
            if i in self.do_not_specialize:
                continue
            specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
                                f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
                                f'else (False,)']
        spec_keys = ', '.join(specializations)
        grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])

        # Source fragments for tuples: "a, b," when there are items and "()"
        # otherwise, so kernels without (regular) arguments still yield valid
        # Python instead of e.g. `sig_key = ,`.
        sig_key_src = f'{sig_keys},' if sig_keys else '()'
        constexpr_key_src = f'{constexpr_keys},' if constexpr_keys else '()'
        spec_key_src = f'{spec_keys},' if spec_keys else '()'
        all_args_src = f"{', '.join(self.arg_names)}," if self.arg_names else '()'
        # Leading parameters of the launcher; empty for argument-less kernels.
        params = ''.join(f'{arg}, ' for arg in self.arg_names)

        # Doubled braces are emitted literally into the generated source;
        # single braces are substituted right here.
        src = f"""
def {self.fn.__name__}({params}grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
    sig_key = {sig_key_src}
    constexpr_key = {constexpr_key_src}
    spec_key = {spec_key_src}
    key = (version_key, sig_key, constexpr_key, spec_key)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    if stream is None and not warmup:
        stream = get_cuda_stream(device)
    # Only the lookup decides hit vs. miss, so a KeyError raised while
    # launching is not mistaken for a cache miss.
    bin = cache.get(key)
    if bin is not None:
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
        return bin
    # kernel not cached -- compile
    # build dict of constant values
    args = [{args}]
    configs = self._get_config(*args),
    constants = self._make_constants(constexpr_key)
    constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
    constants.update({{i: 1 for i in configs[0].equal_to_1}})
    # build kernel signature -- doesn't include specialized arguments
    all_args = {all_args_src}
    signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
    # build stub signature -- includes arguments that are specialized
    for i, arg in constants.items():
        if callable(arg):
            raise TypeError(f"Callable constexpr at index {{i}} is not supported")
    # NOTE(review): compilation always targets device 0, whatever the current device is -- confirm this is intended
    device = 0
    if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
        bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
        self.cache[key] = bin
        return bin
    return None
"""
        scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
                 "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
                 "cache": self.cache, "triton": triton, "torch": torch}
        exec(src, scope)
        return scope[self.fn.__name__]

    def __init__(self, fn, version=None, do_not_specialize=None):
        """
        :param fn: the Python function to wrap
        :param version: stored as ``self.version``
        :param do_not_specialize: argument names or indices that must not be
            specialized; names are converted to indices
        """
        self.fn = fn
        self.module = fn.__module__
        self.version = version
        # function signature information
        signature = inspect.signature(fn)
        parameters = list(signature.parameters.values())
        self.arg_names = [v.name for v in parameters]
        self.has_defaults = any(v.default is not inspect.Parameter.empty for v in parameters)
        # specialization hints
        hints = [] if do_not_specialize is None else do_not_specialize
        self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in hints}
        # function source code (without decorators)
        source = textwrap.dedent(inspect.getsource(fn))
        self.src = source[source.find("def"):]
        # cache of just-in-time compiled kernels
        self.cache = dict()
        self.hash = None
        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel_decorators = []
        self.kernel = None
        # annotations -- the 'return' entry is not an argument and would make
        # arg_names.index() raise ValueError, so it is skipped
        arg_annotations = {name: ty for name, ty in fn.__annotations__.items() if name != 'return'}
        self.annotations = {self.arg_names.index(name): ty for name, ty in arg_annotations.items()}
        self.__annotations__ = fn.__annotations__
        # index of constexprs (every annotated argument is treated as one)
        self.constexprs = [self.arg_names.index(name) for name in arg_annotations]
        # launcher
        self.run = self._make_launcher()
        # re-use docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    @property
    def cache_key(self):
        """Hash of the source and its JIT dependencies plus ``version_key()``; computed on first use."""
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + version_key()
        return self.hash

    def warmup(self, *args, **kwargs):
        """Call ``run`` with ``warmup=True`` after passing each positional argument through ``MockTensor.wrap_dtype``."""
        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)

    # we do not parse `src` in the constructor because
    # the user might want to monkey-patch self.src dynamically.
    # Our unit tests do this, for example.
    def parse(self):
        """Parse ``self.src`` and return the module AST, which must hold exactly one function definition."""
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        """Direct calls are rejected; launch with ``fn[grid](...)`` instead."""
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        # - when kernel decorators change, cached kernel
        #   needs to be cleared
        if name == 'kernel_decorators':
            self.kernel = None
        super().__setattr__(name, value)
        # - when `.src` attribute is set, cache path needs
        #   to be reinitialized
        if name == 'src':
            self.hash = None

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# `jit` decorator
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
@overload
def jit(fn: T) -> JITFunction[T]:
    ...


@overload
def jit(
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
) -> Callable[[T], JITFunction[T]]:
    ...


def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator that turns a function into a :code:`JITFunction`, to be
    compiled with the Triton compiler.

    Usable bare (:code:`@jit`) or with keyword options
    (:code:`@jit(version=..., do_not_specialize=...)`); in the second form a
    decorator is returned.

    :note: When a jit'd function is called, :code:`torch.tensor` arguments are
        implicitly converted to pointers using the :code:`.data_ptr()` method.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        return JITFunction(fn, version=version, do_not_specialize=do_not_specialize)

    # Called with options only: hand back the decorator itself.
    if fn is None:
        return decorator
    return decorator(fn)
|
|
|
|
|
|
class TensorWrapper:
    """
    Pairs a tensor-like ``base`` with a different ``dtype``.

    ``is_cuda`` and ``device`` are copied from ``base`` at construction and
    ``data_ptr()`` is forwarded to it.
    """

    def __init__(self, base, dtype):
        self.base = base
        self.dtype = dtype
        self.is_cuda = base.is_cuda
        self.device = base.device

    def data_ptr(self):
        """Return ``base.data_ptr()``."""
        wrapped = self.base
        return wrapped.data_ptr()

    def __str__(self) -> str:
        dtype, base = self.dtype, self.base
        return f'TensorWrapper[{dtype}]({base})'
|
|
|
|
|
|
def reinterpret(tensor, dtype):
    """
    Return ``tensor`` viewed under ``dtype``.

    A ``torch.Tensor`` is wrapped in a new ``TensorWrapper``.  An existing
    ``TensorWrapper`` is unwrapped when ``dtype`` equals the dtype of its
    base, and otherwise its base is re-wrapped with the new ``dtype``.

    :raises TypeError: if ``tensor`` is neither a ``TensorWrapper`` nor a
        ``torch.Tensor``.
    """
    if isinstance(tensor, TensorWrapper):
        base = tensor.base
        # Back to the base's own dtype: hand out the base itself; any other
        # dtype gets a fresh wrapper around the same base.
        return base if dtype == base.dtype else TensorWrapper(base, dtype)
    if isinstance(tensor, torch.Tensor):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|