[FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)
Changes to make decorated API methods no longer type-opaque. ``` $ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin /dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any" Success: no issues found in 1 source file ```
This commit is contained in:
committed by
GitHub
parent
2fa17588f7
commit
e0072d210a
@@ -2,12 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import List
|
from typing import List, Callable, TypeVar
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
from . import semantic
|
from . import semantic
|
||||||
from triton._C.libtriton.triton import ir
|
from triton._C.libtriton.triton import ir
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
def _to_tensor(x, builder):
|
def _to_tensor(x, builder):
|
||||||
if isinstance(x, bool):
|
if isinstance(x, bool):
|
||||||
@@ -33,7 +34,7 @@ def _to_tensor(x, builder):
|
|||||||
assert False, f'cannot convert {x} to tensor'
|
assert False, f'cannot convert {x} to tensor'
|
||||||
|
|
||||||
|
|
||||||
def builtin(fn):
|
def builtin(fn: T) -> T:
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if '_builder' not in kwargs or \
|
if '_builder' not in kwargs or \
|
||||||
@@ -851,9 +852,9 @@ def store(pointer, value, mask=None, _builder=None):
|
|||||||
# Atomic Memory Operations
|
# Atomic Memory Operations
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
def _add_atomic_docstr(name):
|
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||||
|
|
||||||
def _decorator(func):
|
def _decorator(func: T) -> T:
|
||||||
docstr = """
|
docstr = """
|
||||||
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||||
|
|
||||||
@@ -974,9 +975,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
|||||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||||
|
|
||||||
|
|
||||||
def _add_math_1arg_docstr(name):
|
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||||
|
|
||||||
def _decorator(func):
|
def _decorator(func: T) -> T:
|
||||||
docstr = """
|
docstr = """
|
||||||
Computes the element-wise {name} of :code:`x`
|
Computes the element-wise {name} of :code:`x`
|
||||||
|
|
||||||
@@ -1023,9 +1024,9 @@ def sqrt(x, _builder=None):
|
|||||||
# Reductions
|
# Reductions
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
def _add_reduction_docstr(name):
|
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
|
||||||
|
|
||||||
def _decorator(func):
|
def _decorator(func: T) -> T:
|
||||||
docstr = """
|
docstr = """
|
||||||
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
|
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
|
||||||
|
|
||||||
|
@@ -8,6 +8,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import textwrap
|
import textwrap
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -19,6 +20,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Dependencies Finder
|
# Dependencies Finder
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -94,20 +98,20 @@ def version_key():
|
|||||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||||
|
|
||||||
|
|
||||||
class KernelInterface:
|
class KernelInterface(Generic[T]):
|
||||||
|
run: T
|
||||||
|
|
||||||
def __getitem__(self, grid):
|
def __getitem__(self, grid) -> T:
|
||||||
"""
|
"""
|
||||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||||
memorizes the grid.
|
memorizes the grid.
|
||||||
"""
|
"""
|
||||||
def launcher(*args, **kwargs):
|
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
|
||||||
return self.run(*args, grid=grid, **kwargs)
|
|
||||||
return launcher
|
|
||||||
|
|
||||||
|
|
||||||
class JITFunction(KernelInterface):
|
|
||||||
|
class JITFunction(KernelInterface[T]):
|
||||||
|
|
||||||
cache_hook = None
|
cache_hook = None
|
||||||
divisibility = 16
|
divisibility = 16
|
||||||
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def jit(*args, **kwargs):
|
@overload
|
||||||
|
def jit(fn: T) -> JITFunction[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def jit(
|
||||||
|
*,
|
||||||
|
version=None,
|
||||||
|
do_not_specialize: Optional[Iterable[int]] = None,
|
||||||
|
) -> Callable[[T], JITFunction[T]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def jit(
|
||||||
|
fn: Optional[T] = None,
|
||||||
|
*,
|
||||||
|
version=None,
|
||||||
|
do_not_specialize: Optional[Iterable[int]] = None,
|
||||||
|
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||||
"""
|
"""
|
||||||
Decorator for JIT-compiling a function using the Triton compiler.
|
Decorator for JIT-compiling a function using the Triton compiler.
|
||||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
|
||||||
|
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
|
||||||
|
implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||||
|
|
||||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||||
|
|
||||||
* python primitives,
|
* python primitives,
|
||||||
* objects within the triton.language package,
|
* builtins within the triton package,
|
||||||
* arguments to this function,
|
* arguments to this function,
|
||||||
* other jit'd functions
|
* other jit'd functions
|
||||||
|
|
||||||
:param fn: the function to be jit-compiled
|
:param fn: the function to be jit-compiled
|
||||||
:type fn: Callable
|
:type fn: Callable
|
||||||
"""
|
"""
|
||||||
if args:
|
|
||||||
assert len(args) == 1
|
def decorator(fn: T) -> JITFunction[T]:
|
||||||
assert callable(args[0])
|
assert callable(fn)
|
||||||
return JITFunction(args[0], **kwargs)
|
return JITFunction(
|
||||||
|
fn,
|
||||||
|
version=version,
|
||||||
|
do_not_specialize=do_not_specialize,
|
||||||
|
)
|
||||||
|
|
||||||
|
if fn is not None:
|
||||||
|
return decorator(fn)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
def decorator(fn):
|
|
||||||
return JITFunction(fn, **kwargs)
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user