[FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)

Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
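
The mechanism behind this output is the usual identity-typed decorator pattern: a decorator annotated as taking and returning the same type variable is transparent to mypy, so the wrapped function keeps its declared signature. A minimal standalone sketch of the idea (not code from this commit):

```
from typing import TypeVar

T = TypeVar('T')

def passthrough(fn: T) -> T:
    # As far as mypy is concerned, the decorator returns exactly what it was
    # given, so the decorated function keeps its declared signature.
    return fn

@passthrough
def add(x: int, y: int) -> int:
    return x + y

# reveal_type(add)  # mypy: Revealed type is "def (x: int, y: int) -> builtins.int"
```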
Authored by Crutcher Dunnavant on 2022-12-05 14:41:02 -08:00; committed by GitHub.
Commit e0072d210a (parent 2fa17588f7): 2 changed files with 58 additions and 23 deletions.

File 1 of 2:

```
@@ -2,12 +2,13 @@ from __future__ import annotations
 from enum import Enum
 from functools import wraps
-from typing import List
+from typing import List, Callable, TypeVar
 import triton
 from . import semantic
 from triton._C.libtriton.triton import ir
+T = TypeVar('T')
 def _to_tensor(x, builder):
     if isinstance(x, bool):
```
```
@@ -33,7 +34,7 @@ def _to_tensor(x, builder):
     assert False, f'cannot convert {x} to tensor'
-def builtin(fn):
+def builtin(fn: T) -> T:
     @wraps(fn)
     def wrapper(*args, **kwargs):
         if '_builder' not in kwargs or \
```
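
The hunk above only shows the head of `builtin`; the wrapper body is truncated. As a hedged illustration of how a decorator typed `(fn: T) -> T` can still substitute a `_builder`-injecting wrapper at runtime, here is a standalone sketch (the helper name and the casts are mine, not necessarily what the commit does):

```
from functools import wraps
from typing import Any, Callable, TypeVar, cast

T = TypeVar('T')

def builtin_like(fn: T) -> T:
    # cast needed because an unbounded T is not known to be callable
    @wraps(cast(Callable, fn))
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs.setdefault('_builder', None)  # placeholder for the real IR-builder lookup
        return cast(Callable, fn)(*args, **kwargs)
    # Re-label the wrapper as T so call sites keep the original signature.
    return cast(T, wrapper)

@builtin_like
def minimum(x: int, y: int, _builder: Any = None) -> int:
    return min(x, y)

# reveal_type(minimum)  # roughly: def (x: int, y: int, _builder: Any =) -> builtins.int
```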
```
@@ -851,9 +852,9 @@ def store(pointer, value, mask=None, _builder=None):
 # Atomic Memory Operations
 # -----------------------
-def _add_atomic_docstr(name):
+def _add_atomic_docstr(name: str) -> Callable[[T], T]:
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Performs an atomic {name} at the memory location specified by :code:`pointer`.
```
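
All three docstring-decorator factories in this file get the same typing: the factory takes a `str` and returns a `Callable[[T], T]`, so applying it does not erase the decorated function's signature. A standalone sketch of the pattern (template text and function names here are placeholders, not Triton's):

```
from typing import Callable, TypeVar

T = TypeVar('T')

def _add_docstr(name: str) -> Callable[[T], T]:
    docstr = "Performs an atomic {name} at the memory location specified by `pointer`."

    def _decorator(func: T) -> T:
        # setattr keeps the write well-typed even though T is an arbitrary type.
        setattr(func, '__doc__', docstr.format(name=name))
        return func

    return _decorator

@_add_docstr('max')
def atomic_max(pointer: int, val: float, mask: bool = True) -> float:
    return val  # placeholder body

# help(atomic_max) shows the generated docstring, while mypy still sees
# "def (pointer: int, val: float, mask: bool =) -> float".
```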
```
@@ -974,9 +975,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
     return semantic.fdiv(x, y, ieee_rounding, _builder)
-def _add_math_1arg_docstr(name):
+def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Computes the element-wise {name} of :code:`x`
```
```
@@ -1023,9 +1024,9 @@ def sqrt(x, _builder=None):
 # Reductions
 # -----------------------
-def _add_reduction_docstr(name):
+def _add_reduction_docstr(name: str) -> Callable[[T], T]:
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
```

File 2 of 2:

```
@@ -8,6 +8,7 @@ import os
 import subprocess
 import textwrap
 from collections import namedtuple
+from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
 import torch
```
```
@@ -19,6 +20,9 @@ try:
 except ImportError:
     get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
+T = TypeVar('T')
 # -----------------------------------------------------------------------------
 # Dependencies Finder
 # -----------------------------------------------------------------------------
```
```
@@ -94,20 +98,20 @@ def version_key():
     return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
-class KernelInterface:
+class KernelInterface(Generic[T]):
+    run: T
-    def __getitem__(self, grid):
+    def __getitem__(self, grid) -> T:
         """
         A JIT function is launched with: fn[grid](*args, **kwargs).
         Hence JITFunction.__getitem__ returns a callable proxy that
         memorizes the grid.
         """
-        def launcher(*args, **kwargs):
-            return self.run(*args, grid=grid, **kwargs)
-        return launcher
+        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
-class JITFunction(KernelInterface):
+class JITFunction(KernelInterface[T]):
     cache_hook = None
     divisibility = 16
```
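
The `Generic[T]` rewrite of `KernelInterface` is what lets `fn[grid](*args, **kwargs)` stay typed: `run: T` carries the kernel's signature, and `__getitem__` hands back a grid-bound callable advertised as that same `T`. A standalone sketch of the pattern (the class name and the example signature are illustrative):

```
import functools
from typing import Callable, Generic, TypeVar, cast

T = TypeVar('T')

class LauncherProxy(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        # Inner cast: an unbounded T is not known to be callable.
        # Outer cast: re-label the grid-bound partial with the original signature.
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))

# If `k` has type LauncherProxy[Callable[[int, int], None]], mypy treats
# k[(64,)] as Callable[[int, int], None], so k[(64,)](1, 2) type-checks.
```

Note that the hunk also swaps the nested `launcher` closure for `functools.partial`, which behaves the same at runtime but fits in a single typed return statement.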
```
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
 # -----------------------------------------------------------------------------
-def jit(*args, **kwargs):
+@overload
+def jit(fn: T) -> JITFunction[T]:
+    ...
+@overload
+def jit(
+    *,
+    version=None,
+    do_not_specialize: Optional[Iterable[int]] = None,
+) -> Callable[[T], JITFunction[T]]:
+    ...
+def jit(
+    fn: Optional[T] = None,
+    *,
+    version=None,
+    do_not_specialize: Optional[Iterable[int]] = None,
+) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
     """
     Decorator for JIT-compiling a function using the Triton compiler.
-    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
+    :note: When a jit'd function is called, :code:`torch.tensor` arguments are
+        implicitly converted to pointers using the :code:`.data_ptr()` method.
     :note: This function will be compiled and run on the GPU. It will only have access to:
            * python primitives,
            * objects within the triton.language package,
            * builtins within the triton package,
            * arguments to this function,
            * other jit'd functions
     :param fn: the function to be jit-compiled
     :type fn: Callable
     """
-    if args:
-        assert len(args) == 1
-        assert callable(args[0])
-        return JITFunction(args[0], **kwargs)
+    def decorator(fn: T) -> JITFunction[T]:
+        assert callable(fn)
+        return JITFunction(
+            fn,
+            version=version,
+            do_not_specialize=do_not_specialize,
+        )
+    if fn is not None:
+        return decorator(fn)
     else:
-        def decorator(fn):
-            return JITFunction(fn, **kwargs)
         return decorator
```
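
The two `@overload` declarations are what let mypy resolve both decorator spellings: bare `@triton.jit` yields a `JITFunction[T]` directly, while `@triton.jit(...)` with keyword arguments yields a decorator that produces one. A standalone sketch of the same overload shape on a toy wrapper class (names are illustrative, not Triton's):

```
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload

T = TypeVar('T')

class Wrapped(Generic[T]):
    def __init__(self, fn: T) -> None:
        self.fn = fn

@overload
def deco(fn: T) -> Wrapped[T]:
    ...

@overload
def deco(*, do_not_specialize: Optional[Iterable[int]] = None) -> Callable[[T], Wrapped[T]]:
    ...

def deco(
    fn: Optional[T] = None,
    *,
    do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[Wrapped[T], Callable[[T], Wrapped[T]]]:
    def decorator(f: T) -> Wrapped[T]:
        return Wrapped(f)
    return decorator(fn) if fn is not None else decorator

@deco
def a(x: int) -> int:
    return x

@deco(do_not_specialize=[0])
def b(x: float) -> float:
    return x

# reveal_type(a)  # Wrapped[def (x: int) -> int]
# reveal_type(b)  # Wrapped[def (x: float) -> float]
```

Without the overloads, the single implementation signature would type both spellings as the `Union` return value, forcing every call site to narrow it by hand.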