[FRONTEND] Extract and unify @builtin/@extern (#913)
This change attaches builtin-ness as an explicit attribute, rather than a module prefix expectation. This permits us to source those builtins from multiple sub-modules (useful when some builtins are part of the true cyclic implementation core, and some are just useful library additions); but also prevents accidental inclusion of non-builtins that happen to be in the right library. Once the flag exists, and the compiler is using `is_builtin()` for decision making; the existence of the current `@extern` interface becomes isomorphic to `@builtin`; and the interface can be unified. Leaving `@extern` a thin-wrapper, and encouraging continued use of it, establishes future-proofing towards adding additional extern tracing, metric hooks, or scanning in the future. * Add `triton.impl` package to hold the core, order dependent impl details. * Extract `@builtin` and unify `@extern`; add `is_builtin()` * Add sense bit so that `@builtin` detection is less fragile. * Modify the compiler to use `is_builtin()`
This commit is contained in:
committed by
GitHub
parent
e0072d210a
commit
189491727a
@@ -9,6 +9,7 @@ __version__ = '2.0.0'
|
|||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
|
|
||||||
# submodules
|
# submodules
|
||||||
|
from . import impl
|
||||||
from .utils import (
|
from .utils import (
|
||||||
cdiv,
|
cdiv,
|
||||||
MockTensor,
|
MockTensor,
|
||||||
@@ -36,6 +37,7 @@ __all__ = [
|
|||||||
"compile",
|
"compile",
|
||||||
"Config",
|
"Config",
|
||||||
"heuristics",
|
"heuristics",
|
||||||
|
"impl",
|
||||||
"jit",
|
"jit",
|
||||||
"JITFunction",
|
"JITFunction",
|
||||||
"KernelInterface",
|
"KernelInterface",
|
||||||
|
@@ -25,6 +25,8 @@ from filelock import FileLock
|
|||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
|
|
||||||
|
from . import impl
|
||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
|
||||||
@@ -715,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for i in range(call_op.get_num_results()):
|
for i in range(call_op.get_num_results()):
|
||||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||||
return tuple(results)
|
return tuple(results)
|
||||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
|
||||||
sys.modules[fn.__module__] is triton.language.core or \
|
or impl.is_builtin(fn):
|
||||||
isinstance(fn, triton.language.extern.ExternalFunction):
|
|
||||||
return fn(*args, _builder=self.builder, **kws)
|
return fn(*args, _builder=self.builder, **kws)
|
||||||
if fn in self.builtins.values():
|
if fn in self.builtins.values():
|
||||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||||
|
22
python/triton/impl/__init__.py
Normal file
22
python/triton/impl/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Triton internal implementation details.
|
||||||
|
|
||||||
|
Client libraries should not import interfaces from the `triton.impl` module;
|
||||||
|
as the details are subject to change.
|
||||||
|
|
||||||
|
APIs defined in the `triton.impl` module which are public will be re-exported
|
||||||
|
in other relevant `triton` module namespaces.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from triton._C.libtriton.triton import ir
|
||||||
|
from .base import (
|
||||||
|
builtin,
|
||||||
|
extern,
|
||||||
|
is_builtin,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"builtin",
|
||||||
|
"extern",
|
||||||
|
"ir",
|
||||||
|
"is_builtin",
|
||||||
|
]
|
36
python/triton/impl/base.py
Normal file
36
python/triton/impl/base.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
TRITON_BUILTIN = "__triton_builtin__"
|
||||||
|
|
||||||
|
|
||||||
|
def builtin(fn: T) -> T:
|
||||||
|
"""Mark a function as a builtin."""
|
||||||
|
assert callable(fn)
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Did you forget to add @triton.jit ? "
|
||||||
|
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||||
|
)
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
setattr(wrapper, TRITON_BUILTIN, True)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def is_builtin(fn) -> bool:
|
||||||
|
"""Is this a registered triton builtin function?"""
|
||||||
|
return getattr(fn, TRITON_BUILTIN, False)
|
||||||
|
|
||||||
|
|
||||||
|
def extern(fn: T) -> T:
|
||||||
|
"""A decorator for external functions."""
|
||||||
|
return builtin(fn)
|
@@ -1,8 +1,10 @@
|
|||||||
"""isort:skip_file"""
|
"""isort:skip_file"""
|
||||||
# Import order is significant here.
|
# Import order is significant here.
|
||||||
|
|
||||||
from triton._C.libtriton.triton import ir
|
from ..impl import (
|
||||||
|
ir,
|
||||||
|
builtin,
|
||||||
|
)
|
||||||
from . import core, extern, libdevice, random
|
from . import core, extern, libdevice, random
|
||||||
from .core import (
|
from .core import (
|
||||||
abs,
|
abs,
|
||||||
|
@@ -1,11 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
|
||||||
from typing import List, Callable, TypeVar
|
from typing import List, Callable, TypeVar
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
from . import semantic
|
from . import builtin, semantic
|
||||||
from triton._C.libtriton.triton import ir
|
from triton._C.libtriton.triton import ir
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@@ -34,17 +33,6 @@ def _to_tensor(x, builder):
|
|||||||
assert False, f'cannot convert {x} to tensor'
|
assert False, f'cannot convert {x} to tensor'
|
||||||
|
|
||||||
|
|
||||||
def builtin(fn: T) -> T:
|
|
||||||
@wraps(fn)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
if '_builder' not in kwargs or \
|
|
||||||
kwargs['_builder'] is None:
|
|
||||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class dtype:
|
class dtype:
|
||||||
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
||||||
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||||
|
@@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
|
|||||||
ret_shape = broadcast_arg.shape
|
ret_shape = broadcast_arg.shape
|
||||||
func = getattr(_builder, "create_external_elementwise")
|
func = getattr(_builder, "create_external_elementwise")
|
||||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
class ExternalFunction:
|
|
||||||
'''
|
|
||||||
A wrapper for external functions
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, fn):
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
if '_builder' not in kwargs or \
|
|
||||||
kwargs['_builder'] is None:
|
|
||||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
|
||||||
return self.fn(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def extern(fn):
|
|
||||||
'''
|
|
||||||
A decorator for external functions
|
|
||||||
'''
|
|
||||||
return ExternalFunction(fn)
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user