[RUNTIME] Bunch of bugfixes (#372)

This commit is contained in:
Philippe Tillet
2021-11-12 00:55:00 -08:00
committed by GitHub
parent f7ab96cfd7
commit e66bf76354
5 changed files with 128 additions and 17 deletions

View File

@@ -0,0 +1,66 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
i = function_2(i)
return i
@triton.jit
def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
delattr(function_2.fn, 'hash')
function_1.src = function_1.src.replace(old, new)
target.src = target.src.replace(old, new)
ret = target.cache_key
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def test_reuse():
counter = 0
def inc_counter(key, binary):
nonlocal counter
counter += 1
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
JITFunction.cache_hook = inc_counter
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 43, BLOCK=1024)
assert counter == 1