[FRONTEND] Extract and unify @builtin/@extern (#913)
This change attaches builtin-ness as an explicit attribute, rather than a module prefix expectation. This permits us to source those builtins from multiple sub-modules (useful when some builtins are part of the true cyclic implementation core, and some are just useful library additions); but also prevents accidental inclusion of non-builtins that happen to be in the right library. Once the flag exists, and the compiler is using `is_builtin()` for decision making; the existence of the current `@extern` interface becomes isomorphic to `@builtin`; and the interface can be unified. Leaving `@extern` a thin-wrapper, and encouraging continued use of it, establishes future-proofing towards adding additional extern tracing, metric hooks, or scanning in the future. * Add `triton.impl` package to hold the core, order dependent impl details. * Extract `@builtin` and unify `@extern`; add `is_builtin()` * Add sense bit so that `@builtin` detection is less fragile. * Modify the compiler to use `is_builtin()`
This commit is contained in:
committed by
GitHub
parent
e0072d210a
commit
189491727a
@@ -9,6 +9,7 @@ __version__ = '2.0.0'
|
||||
import torch # noqa: F401
|
||||
|
||||
# submodules
|
||||
from . import impl
|
||||
from .utils import (
|
||||
cdiv,
|
||||
MockTensor,
|
||||
@@ -36,6 +37,7 @@ __all__ = [
|
||||
"compile",
|
||||
"Config",
|
||||
"heuristics",
|
||||
"impl",
|
||||
"jit",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
|
@@ -25,6 +25,8 @@ from filelock import FileLock
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
from . import impl
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
@@ -715,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core or \
|
||||
isinstance(fn, triton.language.extern.ExternalFunction):
|
||||
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
|
||||
or impl.is_builtin(fn):
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
|
22
python/triton/impl/__init__.py
Normal file
22
python/triton/impl/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Triton internal implementation details.
|
||||
|
||||
Client libraries should not import interfaces from the `triton.impl` module;
|
||||
as the details are subject to change.
|
||||
|
||||
APIs defined in the `triton.impl` module which are public will be re-exported
|
||||
in other relevant `triton` module namespaces.
|
||||
"""
|
||||
|
||||
from triton._C.libtriton.triton import ir
|
||||
from .base import (
|
||||
builtin,
|
||||
extern,
|
||||
is_builtin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"builtin",
|
||||
"extern",
|
||||
"ir",
|
||||
"is_builtin",
|
||||
]
|
36
python/triton/impl/base.py
Normal file
36
python/triton/impl/base.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
TRITON_BUILTIN = "__triton_builtin__"
|
||||
|
||||
|
||||
def builtin(fn: T) -> T:
|
||||
"""Mark a function as a builtin."""
|
||||
assert callable(fn)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_builtin(fn) -> bool:
|
||||
"""Is this a registered triton builtin function?"""
|
||||
return getattr(fn, TRITON_BUILTIN, False)
|
||||
|
||||
|
||||
def extern(fn: T) -> T:
|
||||
"""A decorator for external functions."""
|
||||
return builtin(fn)
|
@@ -1,8 +1,10 @@
|
||||
"""isort:skip_file"""
|
||||
# Import order is significant here.
|
||||
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
from ..impl import (
|
||||
ir,
|
||||
builtin,
|
||||
)
|
||||
from . import core, extern, libdevice, random
|
||||
from .core import (
|
||||
abs,
|
||||
|
@@ -1,11 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List, Callable, TypeVar
|
||||
|
||||
import triton
|
||||
from . import semantic
|
||||
from . import builtin, semantic
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
@@ -34,17 +33,6 @@ def _to_tensor(x, builder):
|
||||
assert False, f'cannot convert {x} to tensor'
|
||||
|
||||
|
||||
def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class dtype:
|
||||
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
|
@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_external_elementwise")
|
||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||
|
||||
|
||||
class ExternalFunction:
|
||||
'''
|
||||
A wrapper for external functions
|
||||
'''
|
||||
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def extern(fn):
|
||||
'''
|
||||
A decorator for external functions
|
||||
'''
|
||||
return ExternalFunction(fn)
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user