added reset_to_zero in vector addition (#205)

This commit is contained in:
Philippe Tillet
2021-08-14 10:58:38 -07:00
committed by GitHub
parent c45c2e9684
commit 6e7593b446

View File

@@ -612,7 +612,7 @@ class Launcher:
class Autotuner:
def __init__(self, kernel, arg_names, configs, key):
def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
if not configs:
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else:
@@ -620,6 +620,14 @@ class Autotuner:
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
self.kernel = kernel
# hook to reset all required tensors to zero before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -632,7 +640,9 @@ class Autotuner:
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
def kernel_call():
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
@@ -642,6 +652,7 @@ class Autotuner:
timings = {config: self._bench(*args, config=config, **meta) \
for config in self.configs}
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
config = self.cache[key]
else:
config = self.configs[0]
@@ -709,10 +720,10 @@ class Config:
self.num_stages = num_stages
def autotune(configs, key):
def autotune(configs, key, reset_to_zero=None):
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key)
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
fn.kernel_decorators.append(wrapper)
return fn