From 6e7593b44680fe79022cfce82b539f8f09c6895d Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 14 Aug 2021 10:58:38 -0700
Subject: [PATCH] added reset_to_zero in vector addition (#205)

---
 python/triton/code_gen.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index f91d0e78f..7b9d5b037 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -612,7 +612,7 @@ class Launcher:
 
 
 class Autotuner:
-    def __init__(self, kernel, arg_names, configs, key):
+    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
         if not configs:
             self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
@@ -620,6 +620,14 @@ class Autotuner:
         self.key_idx = [arg_names.index(k) for k in key]
         self.cache = dict()
         self.kernel = kernel
+        # hook to reset all required tensors to zero before relaunching the kernel
+        self.hook = lambda args: 0
+        if reset_to_zero is not None:
+            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+            def _hook(args):
+                for i in self.reset_idx:
+                    args[i].zero_()
+            self.hook = _hook
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -632,7 +640,9 @@ class Autotuner:
         )
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.meta)
-        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+        def kernel_call():
+            self.hook(args)
+            self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
 
     def __call__(self, *args, **meta):
@@ -642,6 +652,7 @@
             timings = {config: self._bench(*args, config=config, **meta) \
                        for config in self.configs}
             self.cache[key] = builtins.min(timings, key=timings.get)
+            self.hook(args)
             config = self.cache[key]
         else:
             config = self.configs[0]
@@ -709,10 +720,10 @@ class Config:
         self.num_stages = num_stages
 
 
-def autotune(configs, key):
+def autotune(configs, key, reset_to_zero=None):
     def decorator(fn):
         def wrapper(kernel):
-            return Autotuner(kernel, fn.arg_names, configs, key)
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
         fn.kernel_decorators.append(wrapper)
         return fn
     return decorator
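
Below is a minimal usage sketch of the new `reset_to_zero` option (an editorial illustration, not part of the patch). It assumes the Triton 1.x kernel style contemporary with this commit, where meta-parameters arrive through `**meta`; the kernel `add_accumulate`, its argument names, and the config values are all hypothetical. The option exists for benchmarking correctness: the autotuner times every candidate config by relaunching the kernel, so a kernel that accumulates into one of its arguments would fold results from earlier timed runs into later ones. Naming that argument in `reset_to_zero` makes the new hook zero it before each timed launch in `_bench`, and once more after tuning in `__call__` so the selected config starts from clean data.

    import torch
    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK': 1024}, num_warps=4, num_stages=2),
            triton.Config({'BLOCK': 2048}, num_warps=8, num_stages=2),
        ],
        key=['N'],
        # 'y_ptr' is read-modify-written below, so it must be zeroed
        # between benchmark runs or each timed launch would keep
        # accumulating on top of the previous launch's output
        reset_to_zero=['y_ptr'],
    )
    @triton.jit
    def add_accumulate(x_ptr, y_ptr, N, **meta):
        BLOCK = meta['BLOCK']
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(y_ptr + offs, x + y, mask=mask)  # y += x

    N = 98432
    x = torch.randn(N, device='cuda')
    y = torch.zeros(N, device='cuda')
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    add_accumulate[grid](x, y, N)  # autotunes on the first call for this N

Note that the strings in `reset_to_zero` (like those in `key`) are resolved with `arg_names.index(k)`, so they must match the kernel's parameter names exactly, and the corresponding arguments must be objects with an in-place `zero_()` method, i.e. torch tensors.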