[FRONTEND] Add warmup for triton.jit() (#684)
This revives #671, removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object.

Co-authored-by: Jason Ansel <jansel@jansel.net>
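For context, warmup() compiles a kernel for a concrete signature without launching it, so plain dtypes can stand in for real tensors (they are wrapped into MockTensor internally). A minimal usage sketch, adapted from the test added in this commit:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

    # Compile ahead of time: dtypes stand in for tensors and nothing is launched.
    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
    assert len(kernel_add.cache) == 1

    # A later warmup (or run) with matching real tensors reuses the cached binary.
    x, y, out = (torch.randn(32, dtype=torch.float32, device="cuda") for _ in range(3))
    kernel_add.warmup(x, y, out, 32, grid=(1,))
    assert len(kernel_add.cache) == 1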
@@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None:
     except BaseException:
         error = True
     assert error is True
+
+
+def test_jit_warmup_cache() -> None:
+    @triton.jit
+    def kernel_add(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx,
+                 tl.load(a + idx) + tl.load(b + idx))
+
+    args = [
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        torch.randn(32, dtype=torch.float32, device="cuda"),
+        32,
+    ]
+    assert len(kernel_add.cache) == 0
+    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
+    assert len(kernel_add.cache) == 1
+    kernel_add.warmup(*args, grid=(1,))
+    assert len(kernel_add.cache) == 1
+    kernel_add.warmup(*args, grid=(1,))
+    assert len(kernel_add.cache) == 1
@@ -68,16 +68,7 @@ class Autotuner(KernelInterface):
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
                 # prune configs
-                pruned_configs = self.configs
-                if self.early_config_prune:
-                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
-                if self.perf_model:
-                    top_k = self.configs_top_k
-                    if isinstance(top_k, float) and top_k <= 1.0:
-                        top_k = int(len(self.configs) * top_k)
-                    if len(pruned_configs) > top_k:
-                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
-                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                pruned_configs = self.prune_configs(kwargs)
                 bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs)
                            for config in pruned_configs}
@@ -94,6 +85,35 @@ class Autotuner(KernelInterface):
             config.pre_hook(self.nargs)
         return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
+    def prune_configs(self, kwargs):
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
+                                            num_warps=config.num_warps)
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
+    def warmup(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        for config in self.prune_configs(kwargs):
+            self.fn.warmup(
+                *args,
+                num_warps=config.num_warps,
+                num_stages=config.num_stages,
+                **kwargs,
+                **config.kwargs,
+            )
+        self.nargs = None
+
 
 class Config:
     """
@@ -12,6 +12,7 @@ from collections import namedtuple
 import torch
 
 import triton
+from triton.utils import MockTensor
 
 try:
     from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
@@ -231,7 +232,7 @@ class JITFunction(KernelInterface):
         grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
 
         src = f"""
-def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None):
+def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
     sig_key = {sig_keys},
     constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()}
     spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()}
@@ -247,10 +248,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
     grid_2 = grid[2] if grid_size > 2 else 1
     device = torch.cuda.current_device()
     torch.cuda.set_device(device)
-    if stream is None:
+    if stream is None and not warmup:
       stream = get_cuda_stream(device)
     try:
       bin = cache[key]
-      bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
+      if not warmup:
+          bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
       return bin
     # kernel not cached -- compile
@@ -271,6 +273,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
      device = 0
     if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
       bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
-      bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
+      if not warmup:
+          bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
       self.cache[key] = bin
       return bin
@@ -317,7 +320,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         self.__module__ = fn.__module__
 
     @property
-    @functools.lru_cache()
     def cache_key(self):
         # TODO : hash should be attribute of `self`
         if self.hash is None:
@@ -326,6 +328,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
             self.hash = dependencies_finder.ret + version_key()
         return self.hash
 
+    def warmup(self, *args, **kwargs):
+        return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True)
+
     # we do not parse `src` in the constructor because
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
@@ -349,7 +354,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         # to be reinitialized
         if name == 'src':
             self.hash = None
-            JITFunction.cache_key.fget.cache_clear()
 
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"
@@ -19,6 +19,24 @@ def next_power_of_2(n):
     return n
 
 
+class MockTensor:
+    """
+    Can be used in place of real tensors when calling:
+        kernel.warmup(MockTensor(torch.float32), ...)
+    """
+    @staticmethod
+    def wrap_dtype(arg):
+        if isinstance(arg, torch.dtype):
+            return MockTensor(arg)
+        return arg
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return 0  # optimistically assumes multiple of 16
+
+
 class TensorWrapper:
     def __init__(self, base, dtype):
         self.dtype = dtype