[FRONTEND] Add warmup for triton.jit() (#684)

This revives #671, removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object.

Co-authored-by: Jason Ansel <jansel@jansel.net>
Authored by Philippe Tillet on 2022-09-21 12:13:20 -07:00, committed by GitHub
parent 6abe813d1c
commit 677ddae618
4 changed files with 80 additions and 16 deletions
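
Before the diffs, a minimal usage sketch (not part of this commit; kernel and argument names are illustrative) of the new JITFunction.warmup entry point, assuming a CUDA device and a build that contains this change: torch dtypes stand in for tensor arguments, and the kernel is compiled and cached without being launched.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.store(out_ptr + idx, tl.load(x_ptr + idx) + tl.load(y_ptr + idx))


# Compile ahead of time: dtypes stand in for tensors, nothing is launched.
add_kernel.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))

# A later launch with a matching signature and specialization reuses the cached binary.
x = torch.randn(32, device="cuda")
y = torch.randn(32, device="cuda")
out = torch.empty_like(x)
add_kernel[(1,)](x, y, out, 32)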


@@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None:
     except BaseException:
         error = True
     assert error is True
+
+
+def test_jit_warmup_cache() -> None:
+    @triton.jit
+    def kernel_add(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx,
+                 tl.load(a + idx) + tl.load(b + idx))
+
+    args = [
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        32,
+    ]
+    assert len(kernel_add.cache) == 0
+    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
+    assert len(kernel_add.cache) == 1
+    kernel_add.warmup(*args, grid=(1,))
+    assert len(kernel_add.cache) == 1
+    kernel_add.warmup(*args, grid=(1,))
+    assert len(kernel_add.cache) == 1


@@ -68,16 +68,7 @@ class Autotuner(KernelInterface):
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
                 # prune configs
-                pruned_configs = self.configs
-                if self.early_config_prune:
-                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
-                if self.perf_model:
-                    top_k = self.configs_top_k
-                    if isinstance(top_k, float) and top_k <= 1.0:
-                        top_k = int(len(self.configs) * top_k)
-                    if len(pruned_configs) > top_k:
-                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
-                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                pruned_configs = self.prune_configs(kwargs)
                 bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs)
                            for config in pruned_configs}
@@ -94,6 +85,35 @@ class Autotuner(KernelInterface):
                 config.pre_hook(self.nargs)
         return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
+    def prune_configs(self, kwargs):
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
+                                            num_warps=config.num_warps)
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
+    def warmup(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        for config in self.prune_configs(kwargs):
+            self.fn.warmup(
+                *args,
+                num_warps=config.num_warps,
+                num_stages=config.num_stages,
+                **kwargs,
+                **config.kwargs,
+            )
+        self.nargs = None
+
 
 class Config:
     """


@@ -12,6 +12,7 @@ from collections import namedtuple
 import torch
 
 import triton
+from triton.utils import MockTensor
 
 try:
     from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
@@ -231,7 +232,7 @@ class JITFunction(KernelInterface):
         grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
 
         src = f"""
-def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
+def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
     sig_key = {sig_keys},
     constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
     spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
@@ -247,11 +248,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
     grid_2 = grid[2] if grid_size > 2 else 1
     device = torch.cuda.current_device()
     torch.cuda.set_device(device)
-    if stream is None:
+    if stream is None and not warmup:
         stream = get_cuda_stream(device)
     try:
         bin = cache[key]
-        bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
+        if not warmup:
+            bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
         return bin
     # kernel not cached -- compile
     except KeyError:
@@ -271,7 +273,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         device = 0
         if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
             bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
-            bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
+            if not warmup:
+                bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
             self.cache[key] = bin
             return bin
         return None
@@ -317,7 +320,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         self.__module__ = fn.__module__
 
     @property
-    @functools.lru_cache()
     def cache_key(self):
         # TODO : hash should be attribute of `self`
         if self.hash is None:
@@ -326,6 +328,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
             self.hash = dependencies_finder.ret + version_key()
         return self.hash
 
+    def warmup(self, *args, **kwargs):
+        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
+
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
@@ -349,7 +354,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         # to be reinitialized
         if name == 'src':
             self.hash = None
-            JITFunction.cache_key.fget.cache_clear()
 
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"


@@ -19,6 +19,24 @@ def next_power_of_2(n):
     return n
 
 
+class MockTensor:
+    """
+    Can be used in place of real tensors when calling:
+        kernel.warmup(MockTensor(torch.float32), ...)
+    """
+    @staticmethod
+    def wrap_dtype(arg):
+        if isinstance(arg, torch.dtype):
+            return MockTensor(arg)
+        return arg
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return 0  # optimistically assumes multiple of 16
+
+
 class TensorWrapper:
     def __init__(self, base, dtype):
         self.dtype = dtype
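
Finally, a small illustration (not part of this commit) of MockTensor's pass-through behavior: it only has to look enough like a tensor for cache-key computation, exposing .dtype and a data_ptr() of 0 that is optimistically treated as 16-byte aligned.

import torch
from triton.utils import MockTensor

# A torch.dtype gets wrapped so the launcher can read .dtype and .data_ptr().
mock = MockTensor.wrap_dtype(torch.float32)
print(mock.dtype)       # torch.float32
print(mock.data_ptr())  # 0 -- optimistically treated as divisible by 16

# Anything that is not a torch.dtype passes through unchanged.
x = torch.empty(4)
assert MockTensor.wrap_dtype(x) is x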