added reset_to_zero in vector addition (#205)
@@ -612,7 +612,7 @@ class Launcher:
 
 
 class Autotuner:
-    def __init__(self, kernel, arg_names, configs, key):
+    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
         if not configs:
             self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
@@ -620,6 +620,14 @@ class Autotuner:
         self.key_idx = [arg_names.index(k) for k in key]
         self.cache = dict()
         self.kernel = kernel
+        # hook to reset all required tensors to zero before re-launching a kernel
+        self.hook = lambda args: 0
+        if reset_to_zero is not None:
+            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+            def _hook(args):
+                for i in self.reset_idx:
+                    args[i].zero_()
+            self.hook = _hook
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -632,7 +640,9 @@ class Autotuner:
         )
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.meta)
-        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+        def kernel_call():
+            self.hook(args)
+            self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
 
     def __call__(self, *args, **meta):
@@ -642,6 +652,7 @@ class Autotuner:
             timings = {config: self._bench(*args, config=config, **meta) \
                        for config in self.configs}
             self.cache[key] = builtins.min(timings, key=timings.get)
+            self.hook(args)
             config = self.cache[key]
         else:
             config = self.configs[0]
@@ -709,10 +720,10 @@ class Config:
         self.num_stages = num_stages
 
 
-def autotune(configs, key):
+def autotune(configs, key, reset_to_zero=None):
     def decorator(fn):
         def wrapper(kernel):
-            return Autotuner(kernel, fn.arg_names, configs, key)
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
 
         fn.kernel_decorators.append(wrapper)
         return fn
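Usage sketch (not part of the commit): the example below shows how the new reset_to_zero argument of triton.autotune could be used. The kernel name, argument names, and config values are hypothetical; the only behaviour taken from the diff is that every argument listed in reset_to_zero has .zero_() called on it before each benchmarked re-launch, so kernels that accumulate into an output buffer are always timed against cleared memory.

import torch
import triton
import triton.language as tl

# Illustrative only: 'sum_kernel', 'X', 'Y', 'N' and the config values are made up.
# The kernel accumulates into Y with an atomic add, so Y is listed in reset_to_zero
# and the Autotuner hook zeroes it before every benchmarked launch.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 1024}, num_warps=4),
        triton.Config({'BLOCK': 2048}, num_warps=8),
    ],
    key=['N'],
    reset_to_zero=['Y'],
)
@triton.jit
def sum_kernel(X, Y, N, **meta):
    BLOCK = meta['BLOCK']
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(X + offs, mask=offs < N, other=0.)
    tl.atomic_add(Y, tl.sum(x, axis=0))

# Launch as usual; during autotuning, Y is reset before each candidate config is timed.
X = torch.randn(4096, device='cuda')
Y = torch.zeros(1, device='cuda')
grid = lambda meta: (triton.cdiv(4096, meta['BLOCK']),)
sum_kernel[grid](X, Y, 4096)

This assumes the Triton 1.x meta-parameter style used by the surrounding code (**meta kwargs and a grid lambda that receives the chosen config's meta-parameters).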