From 677ddae618139014ebd4b0407a767c9b900e81b5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 21 Sep 2022 12:13:20 -0700 Subject: [PATCH] [FRONTEND] Add warmup for triton.jit() (#684) This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object Co-authored-by: Jason Ansel --- python/test/unit/runtime/test_cache.py | 22 ++++++++++++++ python/triton/runtime/autotuner.py | 40 +++++++++++++++++++------- python/triton/runtime/jit.py | 16 +++++++---- python/triton/utils.py | 18 ++++++++++++ 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 6fad3af3d..6d6c0e131 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None: except BaseException: error = True assert error is True + + +def test_jit_warmup_cache() -> None: + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, + tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + assert len(kernel_add.cache) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + assert len(kernel_add.cache) == 1 + kernel_add.warmup(*args, grid=(1,)) + assert len(kernel_add.cache) == 1 + kernel_add.warmup(*args, grid=(1,)) + assert len(kernel_add.cache) == 1 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 2175501b6..8ec16c477 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -68,16 +68,7 @@ class Autotuner(KernelInterface): key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -94,6 +85,35 @@ class Autotuner(KernelInterface): config.pre_hook(self.nargs) return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + class Config: """ diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 025f268ac..0187a7faa 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -12,6 +12,7 @@ from collections import namedtuple import torch import triton +from triton.utils import MockTensor try: from torch._C import _cuda_getCurrentRawStream as get_cuda_stream @@ -231,7 +232,7 @@ class JITFunction(KernelInterface): grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) src = f""" -def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False): sig_key = {sig_keys}, constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} @@ -247,11 +248,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage grid_2 = grid[2] if grid_size > 2 else 1 device = torch.cuda.current_device() torch.cuda.set_device(device) - if stream is None: + if stream is None and not warmup: stream = get_cuda_stream(device) try: bin = cache[key] - bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) return bin # kernel not cached -- compile except KeyError: @@ -271,7 +273,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) - bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) self.cache[key] = bin return bin return None @@ -317,7 +320,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.__module__ = fn.__module__ @property - @functools.lru_cache() def cache_key(self): # TODO : hash should be attribute of `self` if self.hash is None: @@ -326,6 +328,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.hash = dependencies_finder.ret + version_key() return self.hash + def warmup(self, *args, **kwargs): + return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. @@ -349,7 +354,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # to be reinitialized if name == 'src': self.hash = None - JITFunction.cache_key.fget.cache_clear() def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" diff --git a/python/triton/utils.py b/python/triton/utils.py index f446dd06a..2ac84d06e 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,6 +19,24 @@ def next_power_of_2(n): return n +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + @staticmethod + def wrap_dtype(arg): + if isinstance(arg, torch.dtype): + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + def data_ptr(self): + return 0 # optimistically assumes multiple of 16 + + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype