[FRONTEND] Expand __init__ * imports, add __all__ (#912)

Expand `from .foo import *` to full listings, and `__all__` sections.

This reifies the module export listings, which is useful for code
importing this module; without this, clients will need special `mypy`
control pragmas for this library.

This removes a number of `# flake8` control pragmas.

Verified with `flake8`
This commit is contained in:
Crutcher Dunnavant
2022-12-05 14:22:55 -08:00
committed by GitHub
parent e057c65cf0
commit 2fa17588f7
6 changed files with 237 additions and 14 deletions

View File

@@ -1,15 +1,50 @@
"""isort:skip_file"""
# flake8: noqa: F401
__version__ = '2.0.0'
# ---------------------------------------
# Note: import order is significant here.
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
import torch # noqa: F401
# submodules
from .utils import *
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from .utils import (
cdiv,
MockTensor,
next_power_of_2,
reinterpret,
TensorWrapper,
)
from .runtime import (
autotune,
Config,
heuristics,
JITFunction,
KernelInterface,
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import testing
from . import ops
__all__ = [
"autotune",
"cdiv",
"CompilationError",
"compile",
"Config",
"heuristics",
"jit",
"JITFunction",
"KernelInterface",
"language",
"MockTensor",
"next_power_of_2",
"ops",
"reinterpret",
"runtime",
"TensorWrapper",
"testing",
]

View File

@@ -1,4 +1,171 @@
# flake8: noqa: F401
"""isort:skip_file"""
# Import order is significant here.
from triton._C.libtriton.triton import ir
from . import core, extern, libdevice, random
from .core import *
from .random import *
from .core import (
abs,
arange,
argmin,
argmax,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
builtin,
cat,
cdiv,
constexpr,
cos,
debug_barrier,
dot,
dtype,
exp,
fdiv,
float16,
float32,
float64,
float8,
function_type,
int1,
int16,
int32,
int64,
int8,
load,
log,
max,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
pointer_type,
printf,
program_id,
ravel,
sigmoid,
sin,
softmax,
sqrt,
store,
sum,
swizzle2d,
tensor,
trans,
triton,
uint16,
uint32,
uint64,
uint8,
umulhi,
void,
where,
xor_sum,
zeros,
zeros_like,
)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint32_to_uniform_float,
)
__all__ = [
"abs",
"arange",
"argmin",
"argmax",
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
"bfloat16",
"block_type",
"builtin",
"cat",
"cdiv",
"constexpr",
"cos",
"debug_barrier",
"dot",
"dtype",
"exp",
"fdiv",
"float16",
"float32",
"float64",
"float8",
"function_type",
"int1",
"int16",
"int32",
"int64",
"int8",
"ir",
"load",
"log",
"max",
"max_contiguous",
"maximum",
"min",
"minimum",
"multiple_of",
"num_programs",
"pair_uniform_to_normal",
"philox",
"philox_impl",
"pi32_t",
"pointer_type",
"printf",
"program_id",
"rand",
"rand4x",
"randint",
"randint4x",
"randn",
"randn4x",
"ravel",
"sigmoid",
"sin",
"softmax",
"sqrt",
"store",
"sum",
"swizzle2d",
"tensor",
"trans",
"triton",
"uint16",
"uint32",
"uint32_to_uniform_float",
"uint64",
"uint8",
"umulhi",
"void",
"where",
"xor_sum",
"zeros",
"zeros_like",
]

View File

@@ -409,10 +409,10 @@ class constexpr:
def __neg__(self):
return constexpr(-self.value)
def __pos__(self):
return constexpr(+self.value)
def __invert__(self):
return constexpr(~self.value)

View File

@@ -1,5 +1,12 @@
# flake8: noqa: F401
#from .conv import _conv, conv
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
]

View File

@@ -1,3 +1,7 @@
# flake8: noqa: F401
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -1,2 +1,12 @@
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
from .autotuner import Config, Heuristics, autotune, heuristics
from .jit import JITFunction, KernelInterface, version_key
__all__ = [
"Config",
"Heuristics",
"autotune",
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
]