From e0072d210ae647732aac4509673cf1965e279cd7 Mon Sep 17 00:00:00 2001
From: Crutcher Dunnavant
Date: Mon, 5 Dec 2022 14:41:02 -0800
Subject: [PATCH] [FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)

Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
---
 python/triton/language/core.py | 17 ++++-----
 python/triton/runtime/jit.py   | 64 ++++++++++++++++++++++++++--------
 2 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 57a7b3b52..7fb2cf2ae 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 from enum import Enum
 from functools import wraps
-from typing import List
+from typing import List, Callable, TypeVar
 
 import triton
 from . import semantic
 from triton._C.libtriton.triton import ir
 
+T = TypeVar('T')
 
 def _to_tensor(x, builder):
     if isinstance(x, bool):
@@ -33,7 +34,7 @@ def _to_tensor(x, builder):
     assert False, f'cannot convert {x} to tensor'
 
 
-def builtin(fn):
+def builtin(fn: T) -> T:
     @wraps(fn)
     def wrapper(*args, **kwargs):
         if '_builder' not in kwargs or \
@@ -851,9 +852,9 @@ def store(pointer, value, mask=None, _builder=None):
 # Atomic Memory Operations
 # -----------------------
 
-def _add_atomic_docstr(name):
+def _add_atomic_docstr(name: str) -> Callable[[T], T]:
 
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Performs an atomic {name} at the memory location specified by :code:`pointer`.
 
@@ -974,9 +975,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
     return semantic.fdiv(x, y, ieee_rounding, _builder)
 
 
-def _add_math_1arg_docstr(name):
+def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
 
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Computes the element-wise {name} of :code:`x`
 
@@ -1023,9 +1024,9 @@ def sqrt(x, _builder=None):
 # Reductions
 # -----------------------
 
-def _add_reduction_docstr(name):
+def _add_reduction_docstr(name: str) -> Callable[[T], T]:
 
-    def _decorator(func):
+    def _decorator(func: T) -> T:
         docstr = """
     Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
 
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 8d3704fbf..83ee33806 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -8,6 +8,7 @@ import os
 import subprocess
 import textwrap
 from collections import namedtuple
+from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
 
 import torch
 
@@ -19,6 +20,9 @@ try:
 except ImportError:
     get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
 
+
+T = TypeVar('T')
+
 # -----------------------------------------------------------------------------
 # Dependencies Finder
 # -----------------------------------------------------------------------------
@@ -94,20 +98,20 @@ def version_key():
     return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
 
 
-class KernelInterface:
+class KernelInterface(Generic[T]):
+    run: T
 
-    def __getitem__(self, grid):
+    def __getitem__(self, grid) -> T:
         """
         A JIT function is launched with: fn[grid](*args, **kwargs).
         Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
         """
-        def launcher(*args, **kwargs):
-            return self.run(*args, grid=grid, **kwargs)
-        return launcher
+        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
 
 
-class JITFunction(KernelInterface):
+
+class JITFunction(KernelInterface[T]):
 
     cache_hook = None
     divisibility = 16
@@ -367,25 +371,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
 # -----------------------------------------------------------------------------
 
 
-def jit(*args, **kwargs):
+@overload
+def jit(fn: T) -> JITFunction[T]:
+    ...
+
+
+@overload
+def jit(
+    *,
+    version=None,
+    do_not_specialize: Optional[Iterable[int]] = None,
+) -> Callable[[T], JITFunction[T]]:
+    ...
+
+
+def jit(
+    fn: Optional[T] = None,
+    *,
+    version=None,
+    do_not_specialize: Optional[Iterable[int]] = None,
+) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
     """
     Decorator for JIT-compiling a function using the Triton compiler.
-    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
+
+    :note: When a jit'd function is called, :code:`torch.tensor` arguments are
+      implicitly converted to pointers using the :code:`.data_ptr()` method.
+
     :note: This function will be compiled and run on the GPU. It will only have access to:
+
            * python primitives,
-           * objects within the triton.language package,
+           * builtins within the triton package,
            * arguments to this function,
            * other jit'd functions
+
     :param fn: the function to be jit-compiled
     :type fn: Callable
     """
-    if args:
-        assert len(args) == 1
-        assert callable(args[0])
-        return JITFunction(args[0], **kwargs)
+
+    def decorator(fn: T) -> JITFunction[T]:
+        assert callable(fn)
+        return JITFunction(
+            fn,
+            version=version,
+            do_not_specialize=do_not_specialize,
+        )
+
+    if fn is not None:
+        return decorator(fn)
+
     else:
-        def decorator(fn):
-            return JITFunction(fn, **kwargs)
         return decorator
 
 