From 7dc2a70edb386e14074f91ee8eeb1f78f270761c Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 20 Sep 2022 16:05:14 -0700
Subject: [PATCH] Revert "Add .warmup() for triton.jit()" (#682)

Reverts openai/triton#671

It seems like for some reason this caused out-of-memory errors on some of
our internal workloads. I'm reverting this so that HEAD can be used in
production at OpenAI, and I will work on digging into this issue
asynchronously.
---
 python/test/unit/runtime/test_cache.py | 22 --------------
 python/triton/runtime/autotuner.py     | 40 +++++++-------------------
 python/triton/runtime/jit.py           | 29 ++++++-------------
 python/triton/utils.py                 | 18 ------------
 4 files changed, 19 insertions(+), 90 deletions(-)

diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 3540208d3..6fad3af3d 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -150,25 +150,3 @@ def test_constexpr_not_callable() -> None:
     except BaseException:
         error = True
     assert error is True
-
-
-def test_jit_warmup_cache() -> None:
-    @triton.jit
-    def kernel_add(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.store(o + idx,
-                 tl.load(a + idx) + tl.load(b + idx))
-
-    args = [
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        32,
-    ]
-    assert len(kernel_add.cache) == 0
-    kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32)
-    assert len(kernel_add.cache) == 1
-    kernel_add[(1,)].warmup(*args)
-    assert len(kernel_add.cache) == 1
-    kernel_add[(1,)](*args)
-    assert len(kernel_add.cache) == 1
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 8ec16c477..2175501b6 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -68,7 +68,16 @@ class Autotuner(KernelInterface):
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
                 # prune configs
-                pruned_configs = self.prune_configs(kwargs)
+                pruned_configs = self.configs
+                if self.early_config_prune:
+                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
                 bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs)
                            for config in pruned_configs}
@@ -85,35 +94,6 @@ class Autotuner(KernelInterface):
             config.pre_hook(self.nargs)
         return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
-    def prune_configs(self, kwargs):
-        pruned_configs = self.configs
-        if self.early_config_prune:
-            pruned_configs = self.early_config_prune(self.configs, self.nargs)
-        if self.perf_model:
-            top_k = self.configs_top_k
-            if isinstance(top_k, float) and top_k <= 1.0:
-                top_k = int(len(self.configs) * top_k)
-            if len(pruned_configs) > top_k:
-                est_timing = {
-                    config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
-                                            num_warps=config.num_warps)
-                    for config in pruned_configs
-                }
-                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
-        return pruned_configs
-
-    def warmup(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
-        for config in self.prune_configs(kwargs):
-            self.fn.warmup(
-                *args,
-                num_warps=config.num_warps,
-                num_stages=config.num_stages,
-                **kwargs,
-                **config.kwargs,
-            )
-        self.nargs = None
-
 
 class Config:
     """
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 4911c327f..025f268ac 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -12,7 +12,6 @@ from collections import namedtuple
 import torch
 
 import triton
-from triton.utils import MockTensor
 
 try:
     from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
@@ -102,16 +101,9 @@ class KernelInterface:
         Hence JITFunction.__getitem__ returns a callable proxy that
         memorizes the grid.
         """
-        class Launcher:
-            @staticmethod
-            def __call__(*args, **kwargs):
-                return self.run(*args, grid=grid, **kwargs)
-
-            @staticmethod
-            def warmup(*args, **kwargs):
-                return self.warmup(*args, grid=grid, **kwargs)
-
-        return Launcher()
+        def launcher(*args, **kwargs):
+            return self.run(*args, grid=grid, **kwargs)
+        return launcher
 
 
 class JITFunction(KernelInterface):
@@ -239,7 +231,7 @@ class JITFunction(KernelInterface):
         grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
 
         src = f"""
-def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
+def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
     sig_key = {sig_keys},
     constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
     spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
@@ -255,12 +247,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
     grid_2 = grid[2] if grid_size > 2 else 1
     device = torch.cuda.current_device()
     torch.cuda.set_device(device)
-    if stream is None and not warmup:
+    if stream is None:
       stream = get_cuda_stream(device)
     try:
       bin = cache[key]
-      if not warmup:
-        bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
+      bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
       return bin
     # kernel not cached -- compile
     except KeyError:
@@ -280,8 +271,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         device = 0
       if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
         bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
-        if not warmup:
-          bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
+        bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
         self.cache[key] = bin
         return bin
     return None
@@ -327,6 +317,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         self.__module__ = fn.__module__
 
     @property
+    @functools.lru_cache()
     def cache_key(self):
         # TODO : hash should be attribute of `self`
         if self.hash is None:
@@ -335,9 +326,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
             self.hash = dependencies_finder.ret + version_key()
         return self.hash
 
-    def warmup(self, *args, **kwargs):
-        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
-
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
@@ -361,6 +349,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # to be reinitialized if name == 'src': self.hash = None + JITFunction.cache_key.fget.cache_clear() def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" diff --git a/python/triton/utils.py b/python/triton/utils.py index 2ac84d06e..f446dd06a 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,24 +19,6 @@ def next_power_of_2(n): return n -class MockTensor: - """ - Can be used in place of real tensors when calling: - kernel.warmup(MockTensor(torch.float32), ...) - """ - @staticmethod - def wrap_dtype(arg): - if isinstance(arg, torch.dtype): - return MockTensor(arg) - return arg - - def __init__(self, dtype): - self.dtype = dtype - - def data_ptr(self): - return 0 # optimistically assumes multiple of 16 - - class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype
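
For reference, the deleted test above exercises the API this patch removes: kernel[grid].warmup(...) compiled a kernel into JITFunction.cache without launching it, accepting torch dtypes (wrapped by MockTensor) in place of real tensors. After the revert, kernel[grid] once again returns a plain launcher closure over JITFunction.run, so a kernel is only compiled and cached by an actual launch. The sketch below restates both call patterns using the kernel from the removed test_jit_warmup_cache; it assumes a CUDA device, and the commented-out .warmup() call is only valid on a build that still contains openai/triton#671.

import torch
import triton
import triton.language as tl


@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
    # same kernel as the removed test_jit_warmup_cache
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


a, b, o = (torch.randn(32, dtype=torch.float32, device="cuda") for _ in range(3))

# Pre-revert only (openai/triton#671): compile and cache without launching,
# passing dtypes where tensors would normally go.
# kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32)

# Post-revert: kernel_add[(1,)] is a plain closure over self.run(grid=(1,)),
# so the binary is compiled and cached on the first real launch.
kernel_add[(1,)](a, b, o, 32)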
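
The jit.py hunks also restore @functools.lru_cache() on the cache_key property getter, paired with the JITFunction.cache_key.fget.cache_clear() call in __setattr__ whenever src is reassigned. A small self-contained illustration of that caching-plus-invalidation pattern follows; the Expensive class and its fields are made up for the example and are not part of Triton.

import functools


class Expensive:
    def __init__(self, src):
        self.src = src

    @property
    @functools.lru_cache()
    def cache_key(self):
        # stands in for hashing the kernel source and its dependencies
        print("computing cache_key")
        return hash(self.src)


obj = Expensive("kernel v1")
obj.cache_key                             # prints "computing cache_key"
obj.cache_key                             # served from the lru_cache
obj.src = "kernel v2"
Expensive.cache_key.fget.cache_clear()    # what JITFunction.__setattr__ does for 'src'
obj.cache_key                             # recomputed with the new source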