[FRONTEND] Expand __init__ * imports, add __all__ (#912)
Expand `from .foo import *` to full listings, and `__all__` sections. This reifies the module export listings, which is useful for code importing this module; without this, clients will need special `mypy` control pragmas for this library. This removes a number of `# flake8` control pragmas. Verified with `flake8`
This commit is contained in:
committed by
GitHub
parent
e057c65cf0
commit
2fa17588f7
@@ -1,15 +1,50 @@
|
|||||||
"""isort:skip_file"""
|
"""isort:skip_file"""
|
||||||
# flake8: noqa: F401
|
|
||||||
__version__ = '2.0.0'
|
__version__ = '2.0.0'
|
||||||
|
|
||||||
|
# ---------------------------------------
|
||||||
|
# Note: import order is significant here.
|
||||||
|
|
||||||
# TODO: torch needs to be imported first
|
# TODO: torch needs to be imported first
|
||||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||||
import torch
|
import torch # noqa: F401
|
||||||
|
|
||||||
# submodules
|
# submodules
|
||||||
from .utils import *
|
from .utils import (
|
||||||
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
|
cdiv,
|
||||||
|
MockTensor,
|
||||||
|
next_power_of_2,
|
||||||
|
reinterpret,
|
||||||
|
TensorWrapper,
|
||||||
|
)
|
||||||
|
from .runtime import (
|
||||||
|
autotune,
|
||||||
|
Config,
|
||||||
|
heuristics,
|
||||||
|
JITFunction,
|
||||||
|
KernelInterface,
|
||||||
|
)
|
||||||
from .runtime.jit import jit
|
from .runtime.jit import jit
|
||||||
from .compiler import compile, CompilationError
|
from .compiler import compile, CompilationError
|
||||||
from . import language
|
from . import language
|
||||||
from . import testing
|
from . import testing
|
||||||
from . import ops
|
from . import ops
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"autotune",
|
||||||
|
"cdiv",
|
||||||
|
"CompilationError",
|
||||||
|
"compile",
|
||||||
|
"Config",
|
||||||
|
"heuristics",
|
||||||
|
"jit",
|
||||||
|
"JITFunction",
|
||||||
|
"KernelInterface",
|
||||||
|
"language",
|
||||||
|
"MockTensor",
|
||||||
|
"next_power_of_2",
|
||||||
|
"ops",
|
||||||
|
"reinterpret",
|
||||||
|
"runtime",
|
||||||
|
"TensorWrapper",
|
||||||
|
"testing",
|
||||||
|
]
|
||||||
|
@@ -1,4 +1,171 @@
|
|||||||
# flake8: noqa: F401
|
"""isort:skip_file"""
|
||||||
|
# Import order is significant here.
|
||||||
|
|
||||||
|
from triton._C.libtriton.triton import ir
|
||||||
|
|
||||||
from . import core, extern, libdevice, random
|
from . import core, extern, libdevice, random
|
||||||
from .core import *
|
from .core import (
|
||||||
from .random import *
|
abs,
|
||||||
|
arange,
|
||||||
|
argmin,
|
||||||
|
argmax,
|
||||||
|
atomic_add,
|
||||||
|
atomic_and,
|
||||||
|
atomic_cas,
|
||||||
|
atomic_max,
|
||||||
|
atomic_min,
|
||||||
|
atomic_or,
|
||||||
|
atomic_xchg,
|
||||||
|
atomic_xor,
|
||||||
|
bfloat16,
|
||||||
|
block_type,
|
||||||
|
builtin,
|
||||||
|
cat,
|
||||||
|
cdiv,
|
||||||
|
constexpr,
|
||||||
|
cos,
|
||||||
|
debug_barrier,
|
||||||
|
dot,
|
||||||
|
dtype,
|
||||||
|
exp,
|
||||||
|
fdiv,
|
||||||
|
float16,
|
||||||
|
float32,
|
||||||
|
float64,
|
||||||
|
float8,
|
||||||
|
function_type,
|
||||||
|
int1,
|
||||||
|
int16,
|
||||||
|
int32,
|
||||||
|
int64,
|
||||||
|
int8,
|
||||||
|
load,
|
||||||
|
log,
|
||||||
|
max,
|
||||||
|
max_contiguous,
|
||||||
|
maximum,
|
||||||
|
min,
|
||||||
|
minimum,
|
||||||
|
multiple_of,
|
||||||
|
num_programs,
|
||||||
|
pi32_t,
|
||||||
|
pointer_type,
|
||||||
|
printf,
|
||||||
|
program_id,
|
||||||
|
ravel,
|
||||||
|
sigmoid,
|
||||||
|
sin,
|
||||||
|
softmax,
|
||||||
|
sqrt,
|
||||||
|
store,
|
||||||
|
sum,
|
||||||
|
swizzle2d,
|
||||||
|
tensor,
|
||||||
|
trans,
|
||||||
|
triton,
|
||||||
|
uint16,
|
||||||
|
uint32,
|
||||||
|
uint64,
|
||||||
|
uint8,
|
||||||
|
umulhi,
|
||||||
|
void,
|
||||||
|
where,
|
||||||
|
xor_sum,
|
||||||
|
zeros,
|
||||||
|
zeros_like,
|
||||||
|
)
|
||||||
|
from .random import (
|
||||||
|
pair_uniform_to_normal,
|
||||||
|
philox,
|
||||||
|
philox_impl,
|
||||||
|
rand,
|
||||||
|
rand4x,
|
||||||
|
randint,
|
||||||
|
randint4x,
|
||||||
|
randn,
|
||||||
|
randn4x,
|
||||||
|
uint32_to_uniform_float,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"abs",
|
||||||
|
"arange",
|
||||||
|
"argmin",
|
||||||
|
"argmax",
|
||||||
|
"atomic_add",
|
||||||
|
"atomic_and",
|
||||||
|
"atomic_cas",
|
||||||
|
"atomic_max",
|
||||||
|
"atomic_min",
|
||||||
|
"atomic_or",
|
||||||
|
"atomic_xchg",
|
||||||
|
"atomic_xor",
|
||||||
|
"bfloat16",
|
||||||
|
"block_type",
|
||||||
|
"builtin",
|
||||||
|
"cat",
|
||||||
|
"cdiv",
|
||||||
|
"constexpr",
|
||||||
|
"cos",
|
||||||
|
"debug_barrier",
|
||||||
|
"dot",
|
||||||
|
"dtype",
|
||||||
|
"exp",
|
||||||
|
"fdiv",
|
||||||
|
"float16",
|
||||||
|
"float32",
|
||||||
|
"float64",
|
||||||
|
"float8",
|
||||||
|
"function_type",
|
||||||
|
"int1",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"int8",
|
||||||
|
"ir",
|
||||||
|
"load",
|
||||||
|
"log",
|
||||||
|
"max",
|
||||||
|
"max_contiguous",
|
||||||
|
"maximum",
|
||||||
|
"min",
|
||||||
|
"minimum",
|
||||||
|
"multiple_of",
|
||||||
|
"num_programs",
|
||||||
|
"pair_uniform_to_normal",
|
||||||
|
"philox",
|
||||||
|
"philox_impl",
|
||||||
|
"pi32_t",
|
||||||
|
"pointer_type",
|
||||||
|
"printf",
|
||||||
|
"program_id",
|
||||||
|
"rand",
|
||||||
|
"rand4x",
|
||||||
|
"randint",
|
||||||
|
"randint4x",
|
||||||
|
"randn",
|
||||||
|
"randn4x",
|
||||||
|
"ravel",
|
||||||
|
"sigmoid",
|
||||||
|
"sin",
|
||||||
|
"softmax",
|
||||||
|
"sqrt",
|
||||||
|
"store",
|
||||||
|
"sum",
|
||||||
|
"swizzle2d",
|
||||||
|
"tensor",
|
||||||
|
"trans",
|
||||||
|
"triton",
|
||||||
|
"uint16",
|
||||||
|
"uint32",
|
||||||
|
"uint32_to_uniform_float",
|
||||||
|
"uint64",
|
||||||
|
"uint8",
|
||||||
|
"umulhi",
|
||||||
|
"void",
|
||||||
|
"where",
|
||||||
|
"xor_sum",
|
||||||
|
"zeros",
|
||||||
|
"zeros_like",
|
||||||
|
]
|
||||||
|
@@ -1,5 +1,12 @@
|
|||||||
# flake8: noqa: F401
|
|
||||||
# from .conv import _conv, conv
|
# from .conv import _conv, conv
|
||||||
from . import blocksparse
|
from . import blocksparse
|
||||||
from .cross_entropy import _cross_entropy, cross_entropy
|
from .cross_entropy import _cross_entropy, cross_entropy
|
||||||
from .matmul import _matmul, matmul
|
from .matmul import _matmul, matmul
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"blocksparse",
|
||||||
|
"_cross_entropy",
|
||||||
|
"cross_entropy",
|
||||||
|
"_matmul",
|
||||||
|
"matmul",
|
||||||
|
]
|
||||||
|
@@ -1,3 +1,7 @@
|
|||||||
# flake8: noqa: F401
|
|
||||||
from .matmul import matmul
|
from .matmul import matmul
|
||||||
from .softmax import softmax
|
from .softmax import softmax
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"matmul",
|
||||||
|
"softmax",
|
||||||
|
]
|
||||||
|
@@ -1,2 +1,12 @@
|
|||||||
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
from .autotuner import Config, Heuristics, autotune, heuristics
|
||||||
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
from .jit import JITFunction, KernelInterface, version_key
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Config",
|
||||||
|
"Heuristics",
|
||||||
|
"autotune",
|
||||||
|
"heuristics",
|
||||||
|
"JITFunction",
|
||||||
|
"KernelInterface",
|
||||||
|
"version_key",
|
||||||
|
]
|
||||||
|
Reference in New Issue
Block a user