[FRONTEND] Extract and unify @builtin/@extern (#913)

This change attaches builtin-ness as an explicit attribute, rather than
a module prefix expectation. This permits us to source those builtins
from multiple sub-modules (useful when some builtins are part of the
true cyclic implementation core, and some are just useful library
additions); but also prevents accidental inclusion of non-builtins that
happen to be in the right library.

Once the flag exists, and the compiler is using `is_builtin()` for
decision making; the existence of the current `@extern` interface
becomes isomorphic to `@builtin`; and the interface can be unified.

Leaving `@extern` a thin-wrapper, and encouraging continued use of it,
establishes future-proofing towards adding additional extern tracing,
metric hooks, or scanning in the future.

* Add `triton.impl` package to hold the core, order dependent impl
details.
 * Extract `@builtin` and unify `@extern`; add `is_builtin()`
   * Add sense bit so that `@builtin` detection is less fragile.
 * Modify the compiler to use `is_builtin()`
This commit is contained in:
Crutcher Dunnavant
2022-12-05 14:59:41 -08:00
committed by GitHub
parent e0072d210a
commit 189491727a
8 changed files with 269 additions and 239 deletions

View File

@@ -9,6 +9,7 @@ __version__ = '2.0.0'
import torch # noqa: F401
# submodules
from . import impl
from .utils import (
cdiv,
MockTensor,
@@ -36,6 +37,7 @@ __all__ = [
"compile",
"Config",
"heuristics",
"impl",
"jit",
"JITFunction",
"KernelInterface",

View File

@@ -25,6 +25,8 @@ from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from . import impl
from .tools.disasm import extract
@@ -715,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core or \
isinstance(fn, triton.language.extern.ExternalFunction):
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
or impl.is_builtin(fn):
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg

View File

@@ -0,0 +1,22 @@
"""Triton internal implementation details.
Client libraries should not import interfaces from the `triton.impl` module;
as the details are subject to change.
APIs defined in the `triton.impl` module which are public will be re-exported
in other relevant `triton` module namespaces.
"""
from triton._C.libtriton.triton import ir
from .base import (
builtin,
extern,
is_builtin,
)
__all__ = [
"builtin",
"extern",
"ir",
"is_builtin",
]

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from functools import wraps
from typing import TypeVar
T = TypeVar("T")
TRITON_BUILTIN = "__triton_builtin__"
def builtin(fn: T) -> T:
"""Mark a function as a builtin."""
assert callable(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
return wrapper
def is_builtin(fn) -> bool:
"""Is this a registered triton builtin function?"""
return getattr(fn, TRITON_BUILTIN, False)
def extern(fn: T) -> T:
"""A decorator for external functions."""
return builtin(fn)

View File

@@ -1,8 +1,10 @@
"""isort:skip_file"""
# Import order is significant here.
from triton._C.libtriton.triton import ir
from ..impl import (
ir,
builtin,
)
from . import core, extern, libdevice, random
from .core import (
abs,

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
from enum import Enum
from functools import wraps
from typing import List, Callable, TypeVar
import triton
from . import semantic
from . import builtin, semantic
from triton._C.libtriton.triton import ir
T = TypeVar('T')
@@ -34,17 +33,6 @@ def _to_tensor(x, builder):
assert False, f'cannot convert {x} to tensor'
def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
return wrapper
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']

View File

@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_external_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)

File diff suppressed because it is too large Load Diff