Revert "Add .warmup() for triton.jit()" (#682)

Reverts openai/triton#671

For reasons that are still unclear, this caused out-of-memory errors on some
of our internal workloads. I'm reverting it so that HEAD can be used
in production at OpenAI, and I will dig into the issue asynchronously.
Author: Philippe Tillet (committed by GitHub)
Date: 2022-09-20 16:05:14 -07:00
Parent: 48f30550f1
Commit: 7dc2a70edb
4 changed files with 19 additions and 90 deletions

View File

@@ -150,25 +150,3 @@ def test_constexpr_not_callable() -> None:
     except BaseException:
         error = True
     assert error is True
-
-
-def test_jit_warmup_cache() -> None:
-    @triton.jit
-    def kernel_add(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.store(o + idx,
-                 tl.load(a + idx) + tl.load(b + idx))
-
-    args = [
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        torch.randn(32, dtype=torch.float32, device="cuda"),
-        32,
-    ]
-    assert len(kernel_add.cache) == 0
-    kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32)
-    assert len(kernel_add.cache) == 1
-    kernel_add[(1,)].warmup(*args)
-    assert len(kernel_add.cache) == 1
-    kernel_add[(1,)](*args)
-    assert len(kernel_add.cache) == 1
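The deleted test above was the only in-tree user of the reverted .warmup() entry point. After the revert, the compilation cache is populated only by an actual launch on real device tensors. A minimal sketch of the post-revert behaviour, assuming a CUDA device is available; the kernel mirrors the deleted test:

import torch
import triton
import triton.language as tl


@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
    # element-wise add of two length-N vectors
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


# With .warmup() gone, the first real launch is what compiles the kernel
# and fills kernel_add.cache; later launches with the same signature reuse it.
a, b, o = (torch.randn(32, dtype=torch.float32, device="cuda") for _ in range(3))
assert len(kernel_add.cache) == 0
kernel_add[(1,)](a, b, o, 32)
assert len(kernel_add.cache) == 1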

View File

@@ -68,7 +68,16 @@ class Autotuner(KernelInterface):
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
                 # prune configs
-                pruned_configs = self.prune_configs(kwargs)
+                pruned_configs = self.configs
+                if self.early_config_prune:
+                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
                 bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs)
                            for config in pruned_configs}
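The early_config_prune, perf_model and configs_top_k fields consulted by the re-inlined logic above are normally populated through the autotune decorator's prune_configs_by argument. A hedged sketch of that wiring; the kernel, the pruning rule and the cost model below are illustrative, not code from this repository:

import triton
import triton.language as tl


def early_prune(configs, nargs):
    # Hypothetical rule: drop configs whose block exceeds the problem size.
    # nargs maps kernel argument names to the values of the current call.
    kept = [c for c in configs if c.kwargs["BLOCK"] <= nargs["n_elements"]]
    return kept or configs


def perf_model(n_elements, BLOCK, num_stages, num_warps, **kwargs):
    # Hypothetical cost model (lower is better); only the top_k cheapest
    # configs survive to the real benchmarking loop shown in the hunk above.
    return n_elements / (BLOCK * num_warps)


@triton.autotune(
    configs=[triton.Config({"BLOCK": 2 ** i}, num_warps=4) for i in range(7, 11)],
    key=["n_elements"],
    prune_configs_by={
        "early_config_prune": early_prune,
        "perf_model": perf_model,
        "top_k": 2,
    },
)
@triton.jit
def double(x_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2, mask=mask)

Pre-revert, the Autotuner also exposed a warmup() that pre-compiled every surviving config; the next hunk removes it along with the prune_configs helper.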
@@ -85,35 +94,6 @@ class Autotuner(KernelInterface):
             config.pre_hook(self.nargs)
         return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
-
-    def prune_configs(self, kwargs):
-        pruned_configs = self.configs
-        if self.early_config_prune:
-            pruned_configs = self.early_config_prune(self.configs, self.nargs)
-        if self.perf_model:
-            top_k = self.configs_top_k
-            if isinstance(top_k, float) and top_k <= 1.0:
-                top_k = int(len(self.configs) * top_k)
-            if len(pruned_configs) > top_k:
-                est_timing = {
-                    config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
-                                            num_warps=config.num_warps)
-                    for config in pruned_configs
-                }
-                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
-        return pruned_configs
-
-    def warmup(self, *args, **kwargs):
-        self.nargs = dict(zip(self.arg_names, args))
-        for config in self.prune_configs(kwargs):
-            self.fn.warmup(
-                *args,
-                num_warps=config.num_warps,
-                num_stages=config.num_stages,
-                **kwargs,
-                **config.kwargs,
-            )
-        self.nargs = None


 class Config:
     """

View File

@@ -12,7 +12,6 @@ from collections import namedtuple
 import torch

 import triton
-from triton.utils import MockTensor

 try:
     from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
@@ -102,16 +101,9 @@ class KernelInterface:
         Hence JITFunction.__getitem__ returns a callable proxy that
         memorizes the grid.
         """
-        class Launcher:
-            @staticmethod
-            def __call__(*args, **kwargs):
-                return self.run(*args, grid=grid, **kwargs)
-
-            @staticmethod
-            def warmup(*args, **kwargs):
-                return self.warmup(*args, grid=grid, **kwargs)
-
-        return Launcher()
+        def launcher(*args, **kwargs):
+            return self.run(*args, grid=grid, **kwargs)
+        return launcher


 class JITFunction(KernelInterface):
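The deleted Launcher class existed only so the grid proxy could expose warmup() next to __call__; with warmup gone, a plain closure is enough. A minimal sketch of the restored proxy pattern; KernelInterfaceSketch is an illustrative stand-in, not the Triton class:

# fn[grid] returns a closure that remembers the grid and forwards the
# eventual call to run().
class KernelInterfaceSketch:
    def run(self, *args, grid=None, **kwargs):
        print(f"launch on grid={grid} with args={args}")

    def __getitem__(self, grid):
        def launcher(*args, **kwargs):
            return self.run(*args, grid=grid, **kwargs)
        return launcher


k = KernelInterfaceSketch()
k[(1,)]("a", "b")   # equivalent to k.run("a", "b", grid=(1,))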
@@ -239,7 +231,7 @@ class JITFunction(KernelInterface):
         grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])

         src = f"""
-def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
+def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
     sig_key = {sig_keys},
     constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
     spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
@@ -255,11 +247,10 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
     grid_2 = grid[2] if grid_size > 2 else 1
     device = torch.cuda.current_device()
     torch.cuda.set_device(device)
-    if stream is None and not warmup:
+    if stream is None:
       stream = get_cuda_stream(device)
     try:
       bin = cache[key]
-      if not warmup:
-        bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
+      bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
       return bin
     # kernel not cached -- compile
@@ -280,7 +271,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
       device = 0
       if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
         bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
-        if not warmup:
-          bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
+        bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
         self.cache[key] = bin
         return bin
@@ -327,6 +317,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         self.__module__ = fn.__module__

     @property
+    @functools.lru_cache()
     def cache_key(self):
         # TODO : hash should be attribute of `self`
         if self.hash is None:
@@ -335,9 +326,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
             self.hash = dependencies_finder.ret + version_key()
         return self.hash

-    def warmup(self, *args, **kwargs):
-        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
-
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
@@ -361,6 +349,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         # to be reinitialized
         if name == 'src':
             self.hash = None
+            JITFunction.cache_key.fget.cache_clear()

     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"

View File

@@ -19,24 +19,6 @@ def next_power_of_2(n):
     return n
-
-
-class MockTensor:
-    """
-    Can be used in place of real tensors when calling:
-        kernel.warmup(MockTensor(torch.float32), ...)
-    """
-    @staticmethod
-    def wrap_dtype(arg):
-        if isinstance(arg, torch.dtype):
-            return MockTensor(arg)
-        return arg
-
-    def __init__(self, dtype):
-        self.dtype = dtype
-
-    def data_ptr(self):
-        return 0  # optimistically assumes multiple of 16


 class TensorWrapper:
     def __init__(self, base, dtype):
         self.dtype = dtype
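The deleted MockTensor existed so the warmup path could derive a cache key from dtypes alone, without allocating device memory. A standalone restatement of that placeholder idea, under illustrative names (FakeTensor, wrap_dtype) rather than anything still in the library:

import torch


class FakeTensor:
    # Illustrative stand-in for the deleted helper: carries only a dtype and
    # reports a null data pointer, which the specializer optimistically
    # treated as a multiple of 16.
    def __init__(self, dtype):
        self.dtype = dtype

    def data_ptr(self):
        return 0


def wrap_dtype(arg):
    # Bare torch dtypes become placeholders; real tensors and ints pass through.
    return FakeTensor(arg) if isinstance(arg, torch.dtype) else arg


args = [torch.float32, torch.float32, torch.float32, 32]
wrapped = [wrap_dtype(a) for a in args]
print([getattr(a, "dtype", a) for a in wrapped])   # [torch.float32, torch.float32, torch.float32, 32]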