[RUNTIME] Restored do_not_specialize (#374)

This commit is contained in:
Philippe Tillet
2021-11-12 15:06:55 -08:00
committed by GitHub
parent e66bf76354
commit 01cc3d4503
3 changed files with 35 additions and 7 deletions

View File

@@ -4,6 +4,7 @@ from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
import pytest
tmpdir = ".tmp"
@@ -25,6 +26,11 @@ def kernel(X, i, BLOCK: tl.constexpr):
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
@@ -51,16 +57,34 @@ def test_nested1_change():
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary):
nonlocal counter
counter += 1
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 43, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 5, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
assert counter == target