From 2fa17588f728b03b8a0837ad406d6ad6bb8f1459 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 5 Dec 2022 14:22:55 -0800 Subject: [PATCH] [FRONTEND] Expand __init__ * imports, add __all__ (#912) Expand `from .foo import *` to full listings, and `__all__` sections. This reifies the module export listings, which is useful for code importing this module; without this, clients will need special `mypy` control pragmas for this library. This removes a number of `# flake8` control pragmas. Verified with `flake8` --- python/triton/__init__.py | 43 +++++- python/triton/language/__init__.py | 173 +++++++++++++++++++++- python/triton/language/core.py | 4 +- python/triton/ops/__init__.py | 11 +- python/triton/ops/blocksparse/__init__.py | 6 +- python/triton/runtime/__init__.py | 14 +- 6 files changed, 237 insertions(+), 14 deletions(-) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index c620543ee..426a7e40b 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,15 +1,50 @@ """isort:skip_file""" -# flake8: noqa: F401 __version__ = '2.0.0' +# --------------------------------------- +# Note: import order is significant here. + # TODO: torch needs to be imported first # or pybind11 shows `munmap_chunk(): invalid pointer` -import torch +import torch # noqa: F401 + # submodules -from .utils import * -from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface +from .utils import ( + cdiv, + MockTensor, + next_power_of_2, + reinterpret, + TensorWrapper, +) +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, +) from .runtime.jit import jit from .compiler import compile, CompilationError from . import language from . import testing from . import ops + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "reinterpret", + "runtime", + "TensorWrapper", + "testing", +] diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6b0058dd5..f9acbd3dc 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,171 @@ -# flake8: noqa: F401 +"""isort:skip_file""" +# Import order is significant here. + +from triton._C.libtriton.triton import ir + from . import core, extern, libdevice, random -from .core import * -from .random import * +from .core import ( + abs, + arange, + argmin, + argmax, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + builtin, + cat, + cdiv, + constexpr, + cos, + debug_barrier, + dot, + dtype, + exp, + fdiv, + float16, + float32, + float64, + float8, + function_type, + int1, + int16, + int32, + int64, + int8, + load, + log, + max, + max_contiguous, + maximum, + min, + minimum, + multiple_of, + num_programs, + pi32_t, + pointer_type, + printf, + program_id, + ravel, + sigmoid, + sin, + softmax, + sqrt, + store, + sum, + swizzle2d, + tensor, + trans, + triton, + uint16, + uint32, + uint64, + uint8, + umulhi, + void, + where, + xor_sum, + zeros, + zeros_like, +) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint32_to_uniform_float, +) + + +__all__ = [ + "abs", + "arange", + "argmin", + "argmax", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "builtin", + "cat", + "cdiv", + "constexpr", + "cos", + "debug_barrier", + "dot", + "dtype", + "exp", + "fdiv", + "float16", + "float32", + "float64", + "float8", + "function_type", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "load", + "log", + "max", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "printf", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "ravel", + "sigmoid", + "sin", + "softmax", + "sqrt", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint32_to_uniform_float", + "uint64", + "uint8", + "umulhi", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 77f458ba7..57a7b3b52 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -409,10 +409,10 @@ class constexpr: def __neg__(self): return constexpr(-self.value) - + def __pos__(self): return constexpr(+self.value) - + def __invert__(self): return constexpr(~self.value) diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index dcaed8ccf..a55cb08d7 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -1,5 +1,12 @@ -# flake8: noqa: F401 -#from .conv import _conv, conv +# from .conv import _conv, conv from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy from .matmul import _matmul, matmul + +__all__ = [ + "blocksparse", + "_cross_entropy", + "cross_entropy", + "_matmul", + "matmul", +] diff --git a/python/triton/ops/blocksparse/__init__.py b/python/triton/ops/blocksparse/__init__.py index df3353e12..6b24b5377 100644 --- a/python/triton/ops/blocksparse/__init__.py +++ b/python/triton/ops/blocksparse/__init__.py @@ -1,3 +1,7 @@ -# flake8: noqa: F401 from .matmul import matmul from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index d9946c27c..bb7066eb5 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,2 +1,12 @@ -from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401 -from .jit import JITFunction, KernelInterface, version_key # noqa: F401 +from .autotuner import Config, Heuristics, autotune, heuristics +from .jit import JITFunction, KernelInterface, version_key + +__all__ = [ + "Config", + "Heuristics", + "autotune", + "heuristics", + "JITFunction", + "KernelInterface", + "version_key", +]