This allows compiling a Triton kernel in a subprocess. I'm not seeing a ton of speedup from this, but I figure it is a good change anyway.
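For reference, the pattern the new test exercises looks roughly like this (a minimal, self-contained sketch: the dummy kernel, its signature, and the constants are made up for illustration, and the exact triton.compile keyword arguments vary between Triton versions):

    import multiprocessing

    import triton
    import triton.language as tl


    @triton.jit
    def _dummy(x_ptr, N: tl.constexpr):
        # hypothetical toy kernel, only here to give triton.compile something to build
        tl.store(x_ptr + tl.arange(0, N), 0.0)


    if __name__ == "__main__":
        # Compiling in a child process only warms the on-disk cache; the parent
        # process' state stays untouched. These kwargs mirror the new
        # test_compile_in_subproc test below.
        proc = multiprocessing.Process(
            target=triton.compile,
            kwargs=dict(
                fn=_dummy,
                signature={0: "*fp32"},
                constants={1: 128},
                warm_cache_only=True,
            ))
        proc.start()
        proc.join()
        assert proc.exitcode == 0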
import multiprocessing
import os
import re
import shutil
from collections import namedtuple

import pytest
import torch

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction

tmpdir = ".tmp"


@triton.jit
def function_1(i):
    i = i + 1
    i = function_2(i)
    return i


@triton.jit
def function_2(i):
    i = i + 1
    return i


@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)


@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
    i = i + 1
    i = function_1(i)
    tl.store(X, i)


def apply_src_change(target, old, new):
    # Force hashes to be recomputed, patch the kernel sources, read the new
    # cache key, then restore the target's source.
    kernel.hash = None
    function_1.hash = None
    function_2.hash = None
    function_1.src = function_1.src.replace(old, new)
    target.src = target.src.replace(old, new)
    ret = target.cache_key
    target.src = target.src.replace(new, old)
    return ret


def test_nochange():
    baseline = kernel.cache_key
    updated = apply_src_change(kernel, 'i + 1', 'i + 1')
    assert baseline == updated


def test_toplevel_change():
    baseline = kernel.cache_key
    updated = apply_src_change(kernel, 'i + 1', 'i + 2')
    assert baseline != updated


def test_nested1_change():
    baseline = kernel.cache_key
    updated = apply_src_change(function_1, 'i + 1', 'i + 2')
    assert baseline != updated


def reset_tmp_dir():
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir)


def test_reuse():
    # The cache hook fires on compilation only: the first launch compiles,
    # the remaining launches hit the cache.
    counter = 0

    def inc_counter(*args, **kwargs):
        nonlocal counter
        counter += 1
    JITFunction.cache_hook = inc_counter
    reset_tmp_dir()
    x = torch.empty(1, dtype=torch.int32, device='cuda')
    for i in range(10):
        kernel[(1,)](x, 1, BLOCK=1024)
    assert counter == 1


@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
    counter = 0

    def inc_counter(*args, **kwargs):
        nonlocal counter
        counter += 1
    JITFunction.cache_hook = inc_counter
    reset_tmp_dir()
    x = torch.empty(1, dtype=torch.int32, device='cuda')
    function = {'enable': kernel, 'disable': kernel_nospec}[mode]
    # With specialization, i == 1 and the multiples of 16 each get their own
    # variant on top of the generic one, so 3 compilations are expected;
    # with do_not_specialize there is only 1.
    target = {'enable': 3, 'disable': 1}[mode]
    for i in [1, 2, 4, 8, 16, 32]:
        function[(1,)](x, i, BLOCK=512)
    assert counter == target


@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
    (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
    # The cache key reported by the hook should record the integer scalar
    # argument with the expected specialization type (i32/u32/i64/u64).
    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        nonlocal cache_str
        cache_str = kwargs["repr"]
    triton.JITFunction.cache_hook = get_cache_str
    reset_tmp_dir()
    x = torch.tensor([3.14159], device='cuda')
    kernel[(1, )](value, x)
    triton.JITFunction.cache_hook = None

    cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type


def test_constexpr_not_callable() -> None:
    @triton.jit
    def kernel(X, c: tl.constexpr):
        tl.store(X, 2)

    x = torch.empty(1, dtype=torch.int32, device='cuda')
    error = False
    try:
        kernel[(1, )](x, c="str")
    except BaseException:
        error = True
    assert error is False
    # passing a callable as a constexpr should raise
    try:
        kernel[(1, )](x, c=tl.abs)
    except BaseException:
        error = True
    assert error is True


def test_jit_warmup_cache() -> None:
    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx,
                 tl.load(a + idx) + tl.load(b + idx))

    args = [
        torch.randn(32, dtype=torch.float32, device="cuda"),
        torch.randn(32, dtype=torch.float32, device="cuda"),
        torch.randn(32, dtype=torch.float32, device="cuda"),
        32,
    ]
    assert len(kernel_add.cache) == 0
    # Warming up with dtypes and warming up with real tensors of the same
    # dtypes should map to a single cache entry.
    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
    assert len(kernel_add.cache) == 1
    kernel_add.warmup(*args, grid=(1,))
    assert len(kernel_add.cache) == 1
    kernel_add.warmup(*args, grid=(1,))
    assert len(kernel_add.cache) == 1


def test_compile_in_subproc() -> None:
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx,
                 tl.load(a + idx) - tl.load(b + idx) * 777)

    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    # All four arguments (by index) are marked divisible by 16; none is equal to 1.
    config = namedtuple("instance_descriptor", [
        "divisible_by_16", "equal_to_1"])(
        tuple(range(4)),
        ())

    # Compile the kernel in a child process, only warming the cache; the
    # subprocess exiting cleanly is the whole test.
    proc = multiprocessing.Process(
        target=triton.compile,
        kwargs=dict(
            fn=kernel_sub,
            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
            device=0,
            constants={3: 32},
            configs=[config],
            warm_cache_only=True,
            cc=cc,
        ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0