[FRONTEND] Add warmup for triton.jit() (#684)

This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object

Co-authored-by: Jason Ansel <jansel@jansel.net>
This commit is contained in:
Philippe Tillet
2022-09-21 12:13:20 -07:00
committed by GitHub
parent 6abe813d1c
commit 677ddae618
4 changed files with 80 additions and 16 deletions

View File

@@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None:
except BaseException:
error = True
assert error is True
def test_jit_warmup_cache() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
32,
]
assert len(kernel_add.cache) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1