[FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)

Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
This commit is contained in:
Crutcher Dunnavant
2022-12-05 14:41:02 -08:00
committed by GitHub
parent 2fa17588f7
commit e0072d210a
2 changed files with 58 additions and 23 deletions

View File

@@ -2,12 +2,13 @@ from __future__ import annotations
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import List from typing import List, Callable, TypeVar
import triton import triton
from . import semantic from . import semantic
from triton._C.libtriton.triton import ir from triton._C.libtriton.triton import ir
T = TypeVar('T')
def _to_tensor(x, builder): def _to_tensor(x, builder):
if isinstance(x, bool): if isinstance(x, bool):
@@ -33,7 +34,7 @@ def _to_tensor(x, builder):
assert False, f'cannot convert {x} to tensor' assert False, f'cannot convert {x} to tensor'
def builtin(fn): def builtin(fn: T) -> T:
@wraps(fn) @wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \ if '_builder' not in kwargs or \
@@ -851,9 +852,9 @@ def store(pointer, value, mask=None, _builder=None):
# Atomic Memory Operations # Atomic Memory Operations
# ----------------------- # -----------------------
def _add_atomic_docstr(name): def _add_atomic_docstr(name: str) -> Callable[[T], T]:
def _decorator(func): def _decorator(func: T) -> T:
docstr = """ docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`. Performs an atomic {name} at the memory location specified by :code:`pointer`.
@@ -974,9 +975,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
return semantic.fdiv(x, y, ieee_rounding, _builder) return semantic.fdiv(x, y, ieee_rounding, _builder)
def _add_math_1arg_docstr(name): def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
def _decorator(func): def _decorator(func: T) -> T:
docstr = """ docstr = """
Computes the element-wise {name} of :code:`x` Computes the element-wise {name} of :code:`x`
@@ -1023,9 +1024,9 @@ def sqrt(x, _builder=None):
# Reductions # Reductions
# ----------------------- # -----------------------
def _add_reduction_docstr(name): def _add_reduction_docstr(name: str) -> Callable[[T], T]:
def _decorator(func): def _decorator(func: T) -> T:
docstr = """ docstr = """
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`

View File

@@ -8,6 +8,7 @@ import os
import subprocess import subprocess
import textwrap import textwrap
from collections import namedtuple from collections import namedtuple
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
import torch import torch
@@ -19,6 +20,9 @@ try:
except ImportError: except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
T = TypeVar('T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Dependencies Finder # Dependencies Finder
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -94,20 +98,20 @@ def version_key():
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface: class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid): def __getitem__(self, grid) -> T:
""" """
A JIT function is launched with: fn[grid](*args, **kwargs). A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid. memorizes the grid.
""" """
def launcher(*args, **kwargs): return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
return self.run(*args, grid=grid, **kwargs)
return launcher
class JITFunction(KernelInterface):
class JITFunction(KernelInterface[T]):
cache_hook = None cache_hook = None
divisibility = 16 divisibility = 16
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def jit(*args, **kwargs): @overload
def jit(fn: T) -> JITFunction[T]:
...
@overload
def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Callable[[T], JITFunction[T]]:
...
def jit(
fn: Optional[T] = None,
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
""" """
Decorator for JIT-compiling a function using the Triton compiler. Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to: :note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives, * python primitives,
* objects within the triton.language package, * builtins within the triton package,
* arguments to this function, * arguments to this function,
* other jit'd functions * other jit'd functions
:param fn: the function to be jit-compiled :param fn: the function to be jit-compiled
:type fn: Callable :type fn: Callable
""" """
if args:
assert len(args) == 1 def decorator(fn: T) -> JITFunction[T]:
assert callable(args[0]) assert callable(fn)
return JITFunction(args[0], **kwargs) return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
)
if fn is not None:
return decorator(fn)
else: else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator return decorator