[STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that these are all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
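For anyone who wants to reproduce the same pass programmatically rather than via the CLI, here is a minimal sketch. It is not part of the PR: the `SRC` string is a made-up example, the real run used the shell commands above, and the `.isort.cfg` that drove the import grouping is not reproduced here (it lands in a separate change). It assumes isort >= 5 (which exposes `isort.code`) and autopep8 (which exposes `autopep8.fix_code`).

```python
# Illustrative only -- not part of this PR. Assumes `pip install isort autopep8`.
import autopep8
import isort

# A made-up snippet with unsorted imports and missing whitespace.
SRC = """import sys
import ast
x=lambda v:(v+1)//2
"""

# Step 1: group and order the imports (the PR's run also read an .isort.cfg,
# whose contents are not shown here).
sorted_src = isort.code(SRC)

# Step 2: apply whitespace-only fixes; E501/E701/E731 are ignored, mirroring
# the flags used in the PR.
fixed_src = autopep8.fix_code(sorted_src, options={"ignore": ["E501", "E701", "E731"]})

print(fixed_src)
```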
committed by GitHub
parent 120cda015e
commit 8bf551ae7a
@@ -1,26 +1,26 @@
import ast
import builtins
import dbm
import functools
import inspect
import struct
import sys
import textwrap
import hashlib
import inspect
import os
import pickle
import struct
import subprocess
import os
import sys
import tempfile
import textwrap
import time
import warnings
from .tools.disasm import extract
from typing import Dict, Optional

import torch
from filelock import FileLock

import triton
import triton._C.libtriton.triton as _triton
from filelock import FileLock
import dbm
import tempfile
from typing import Optional, Dict
import time

from .tools.disasm import extract


class CodeGenerator(ast.NodeVisitor):
@@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1]
arg_node = node.args.args[-i - 1]
annotation = arg_node.annotation
name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store())
@@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor):
fn.args[idx].name = arg_name
arg_values.append(fn.args[idx])
idx += 1



for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign
return self.visit_Assign(node)


def visit_Assign(self, node):
_names = []
for target in node.targets:
@@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
if else_bb:
self.builder.set_insert_block(else_bb)
is_terminator = self.visit_compound_statement(node.orelse)
#TODO: last statement is a terminator?
# TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
@@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
_builder=self.builder)
build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),
self.visit(neg_cond_node),
_builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation
@@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws)

@@ -505,10 +503,10 @@ class Binary:

class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
self.bin = bin
self.asm = bin.asm
@@ -520,8 +518,8 @@ class LoadedBinary:

def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem)

def get_sass(self, fun=None):
@@ -632,10 +630,14 @@ class Kernel:

@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1

def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize}

# transforms ints whose value is one into constants for just-in-time compilation
@@ -705,7 +707,7 @@ class Kernel:
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
@@ -766,13 +768,12 @@ class Launcher:

def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid)



class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

def _hook(args):
for i in self.reset_idx:
args[i].zero_()
@@ -802,7 +804,7 @@ class Autotuner:
perf_model, top_k, prune_num_stages_by = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by


def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
@@ -814,6 +816,7 @@ class Autotuner:
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)

def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
@@ -836,9 +839,9 @@ class Autotuner:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)

#########################3
# 3


class DependenciesFinder(ast.NodeVisitor):
@@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor):

def visit_Name(self, node):
return self.globals.get(node.id, None)


def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
@@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()

class JITFunction:

cache_hook = None

class JITFunction:

cache_hook = None

def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs
self.__doc__ = fn.__doc__


@property
@functools.lru_cache()
def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""

def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
@@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
.. highlight:: python
.. code-block:: python

@triton.autotune(configs=[
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']


:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.

@@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
@@ -1099,7 +1102,7 @@ def heuristics(values):
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size



.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
.type values: dict[str, Callable[[list[Any]], Any]]
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y):
return (x + y - 1) // y


def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
@@ -1163,13 +1167,14 @@ def next_power_of_2(n):

######


class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device


def data_ptr(self):
return self.base.data_ptr()