[FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)
Changes to make decorated API methods no longer type-opaque. ``` $ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin /dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any" Success: no issues found in 1 source file ```
This commit is contained in:
committed by
GitHub
parent
2fa17588f7
commit
e0072d210a
@@ -2,12 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from typing import List, Callable, TypeVar
|
||||
|
||||
import triton
|
||||
from . import semantic
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def _to_tensor(x, builder):
|
||||
if isinstance(x, bool):
|
||||
@@ -33,7 +34,7 @@ def _to_tensor(x, builder):
|
||||
assert False, f'cannot convert {x} to tensor'
|
||||
|
||||
|
||||
def builtin(fn):
|
||||
def builtin(fn: T) -> T:
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
@@ -851,9 +852,9 @@ def store(pointer, value, mask=None, _builder=None):
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
|
||||
def _add_atomic_docstr(name):
|
||||
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||
|
||||
@@ -974,9 +975,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Computes the element-wise {name} of :code:`x`
|
||||
|
||||
@@ -1023,9 +1024,9 @@ def sqrt(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
def _add_reduction_docstr(name):
|
||||
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
|
||||
|
||||
|
@@ -8,6 +8,7 @@ import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,6 +20,9 @@ try:
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -94,20 +98,20 @@ def version_key():
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface:
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
def __getitem__(self, grid):
|
||||
def __getitem__(self, grid) -> T:
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
def launcher(*args, **kwargs):
|
||||
return self.run(*args, grid=grid, **kwargs)
|
||||
return launcher
|
||||
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
|
||||
|
||||
|
||||
class JITFunction(KernelInterface):
|
||||
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
cache_hook = None
|
||||
divisibility = 16
|
||||
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
@overload
|
||||
def jit(fn: T) -> JITFunction[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def jit(
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Callable[[T], JITFunction[T]]:
|
||||
...
|
||||
|
||||
|
||||
def jit(
|
||||
fn: Optional[T] = None,
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
|
||||
implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* builtins within the triton package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
return JITFunction(
|
||||
fn,
|
||||
version=version,
|
||||
do_not_specialize=do_not_specialize,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
return decorator(fn)
|
||||
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user