[FRONTEND] Add .warmup() for triton.jit() (#671)

This commit is contained in:
Jason Ansel
2022-09-18 23:09:34 -07:00
committed by GitHub
parent 82956e5d6b
commit 93b1adc53b
4 changed files with 90 additions and 19 deletions

View File

@@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None:
except BaseException:
error = True
assert error is True
def test_jit_warmup_cache() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
32,
]
assert len(kernel_add.cache) == 0
kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32)
assert len(kernel_add.cache) == 1
kernel_add[(1,)].warmup(*args)
assert len(kernel_add.cache) == 1
kernel_add[(1,)](*args)
assert len(kernel_add.cache) == 1