added reset_to_zero in vector addition (#205)
This commit is contained in:
@@ -612,7 +612,7 @@ class Launcher:
|
|||||||
|
|
||||||
|
|
||||||
class Autotuner:
|
class Autotuner:
|
||||||
def __init__(self, kernel, arg_names, configs, key):
|
def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
|
||||||
if not configs:
|
if not configs:
|
||||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||||
else:
|
else:
|
||||||
@@ -620,6 +620,14 @@ class Autotuner:
|
|||||||
self.key_idx = [arg_names.index(k) for k in key]
|
self.key_idx = [arg_names.index(k) for k in key]
|
||||||
self.cache = dict()
|
self.cache = dict()
|
||||||
self.kernel = kernel
|
self.kernel = kernel
|
||||||
|
# hook that resets all required tensors to zero before relaunching a kernel
|
||||||
|
self.hook = lambda args: 0
|
||||||
|
if reset_to_zero is not None:
|
||||||
|
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||||
|
def _hook(args):
|
||||||
|
for i in self.reset_idx:
|
||||||
|
args[i].zero_()
|
||||||
|
self.hook = _hook
|
||||||
|
|
||||||
def _bench(self, *args, config, **meta):
|
def _bench(self, *args, config, **meta):
|
||||||
# check for conflicts, i.e. meta-parameters both provided
|
# check for conflicts, i.e. meta-parameters both provided
|
||||||
@@ -632,7 +640,9 @@ class Autotuner:
|
|||||||
)
|
)
|
||||||
# augment meta-parameters with tunable ones
|
# augment meta-parameters with tunable ones
|
||||||
current = dict(meta, **config.meta)
|
current = dict(meta, **config.meta)
|
||||||
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
def kernel_call():
|
||||||
|
self.hook(args)
|
||||||
|
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||||
return triton.testing.do_bench(kernel_call)
|
return triton.testing.do_bench(kernel_call)
|
||||||
|
|
||||||
def __call__(self, *args, **meta):
|
def __call__(self, *args, **meta):
|
||||||
@@ -642,6 +652,7 @@ class Autotuner:
|
|||||||
timings = {config: self._bench(*args, config=config, **meta) \
|
timings = {config: self._bench(*args, config=config, **meta) \
|
||||||
for config in self.configs}
|
for config in self.configs}
|
||||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||||
|
self.hook(args)
|
||||||
config = self.cache[key]
|
config = self.cache[key]
|
||||||
else:
|
else:
|
||||||
config = self.configs[0]
|
config = self.configs[0]
|
||||||
@@ -709,10 +720,10 @@ class Config:
|
|||||||
self.num_stages = num_stages
|
self.num_stages = num_stages
|
||||||
|
|
||||||
|
|
||||||
def autotune(configs, key):
|
def autotune(configs, key, reset_to_zero=None):
|
||||||
def decorator(fn):
|
def decorator(fn):
|
||||||
def wrapper(kernel):
|
def wrapper(kernel):
|
||||||
return Autotuner(kernel, fn.arg_names, configs, key)
|
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
||||||
|
|
||||||
fn.kernel_decorators.append(wrapper)
|
fn.kernel_decorators.append(wrapper)
|
||||||
return fn
|
return fn
|
||||||
|
Reference in New Issue
Block a user