[FRONTEND] Extract and unify @builtin/@extern (#913)

This change attaches builtin-ness as an explicit attribute, rather than
a module prefix expectation. This permits us to source those builtins
from multiple sub-modules (useful when some builtins are part of the
true cyclic implementation core, and some are just useful library
additions); but also prevents accidental inclusion of non-builtins that
happen to be in the right library.

Once the flag exists, and the compiler is using `is_builtin()` for
decision making; the existence of the current `@extern` interface
becomes isomorphic to `@builtin`; and the interface can be unified.

Leaving `@extern` a thin-wrapper, and encouraging continued use of it,
establishes future-proofing towards adding additional extern tracing,
metric hooks, or scanning in the future.

* Add `triton.impl` package to hold the core, order dependent impl
details.
 * Extract `@builtin` and unify `@extern`; add `is_builtin()`
   * Add sense bit so that `@builtin` detection is less fragile.
 * Modify the compiler to use `is_builtin()`
This commit is contained in:
Crutcher Dunnavant
2022-12-05 14:59:41 -08:00
committed by GitHub
parent e0072d210a
commit 189491727a
8 changed files with 269 additions and 239 deletions

View File

@@ -9,6 +9,7 @@ __version__ = '2.0.0'
import torch # noqa: F401 import torch # noqa: F401
# submodules # submodules
from . import impl
from .utils import ( from .utils import (
cdiv, cdiv,
MockTensor, MockTensor,
@@ -36,6 +37,7 @@ __all__ = [
"compile", "compile",
"Config", "Config",
"heuristics", "heuristics",
"impl",
"jit", "jit",
"JITFunction", "JITFunction",
"KernelInterface", "KernelInterface",

View File

@@ -25,6 +25,8 @@ from filelock import FileLock
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from . import impl
from .tools.disasm import extract from .tools.disasm import extract
@@ -715,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
for i in range(call_op.get_num_results()): for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results) return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
sys.modules[fn.__module__] is triton.language.core or \ or impl.is_builtin(fn):
isinstance(fn, triton.language.extern.ExternalFunction):
return fn(*args, _builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values(): if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg args = [arg.value if isinstance(arg, triton.language.constexpr) else arg

View File

@@ -0,0 +1,22 @@
"""Triton internal implementation details.
Client libraries should not import interfaces from the `triton.impl` module;
as the details are subject to change.
APIs defined in the `triton.impl` module which are public will be re-exported
in other relevant `triton` module namespaces.
"""
from triton._C.libtriton.triton import ir
from .base import (
builtin,
extern,
is_builtin,
)
__all__ = [
"builtin",
"extern",
"ir",
"is_builtin",
]

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from functools import wraps
from typing import TypeVar
T = TypeVar("T")
TRITON_BUILTIN = "__triton_builtin__"
def builtin(fn: T) -> T:
"""Mark a function as a builtin."""
assert callable(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
return wrapper
def is_builtin(fn) -> bool:
"""Is this a registered triton builtin function?"""
return getattr(fn, TRITON_BUILTIN, False)
def extern(fn: T) -> T:
"""A decorator for external functions."""
return builtin(fn)

View File

@@ -1,8 +1,10 @@
"""isort:skip_file""" """isort:skip_file"""
# Import order is significant here. # Import order is significant here.
from triton._C.libtriton.triton import ir from ..impl import (
ir,
builtin,
)
from . import core, extern, libdevice, random from . import core, extern, libdevice, random
from .core import ( from .core import (
abs, abs,

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from functools import wraps
from typing import List, Callable, TypeVar from typing import List, Callable, TypeVar
import triton import triton
from . import semantic from . import builtin, semantic
from triton._C.libtriton.triton import ir from triton._C.libtriton.triton import ir
T = TypeVar('T') T = TypeVar('T')
@@ -34,17 +33,6 @@ def _to_tensor(x, builder):
assert False, f'cannot convert {x} to tensor' assert False, f'cannot convert {x} to tensor'
def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
return wrapper
class dtype: class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']

View File

@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
ret_shape = broadcast_arg.shape ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_external_elementwise") func = getattr(_builder, "create_external_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)

File diff suppressed because it is too large Load Diff