[FRONTEND] Add warmup for triton.jit() (#684)
This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object Co-authored-by: Jason Ansel <jansel@jansel.net>
This commit is contained in:
@@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None:
|
||||
except BaseException:
|
||||
error = True
|
||||
assert error is True
|
||||
|
||||
|
||||
def test_jit_warmup_cache() -> None:
|
||||
@triton.jit
|
||||
def kernel_add(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.store(o + idx,
|
||||
tl.load(a + idx) + tl.load(b + idx))
|
||||
|
||||
args = [
|
||||
torch.randn(32, dtype=torch.float32, device="cuda"),
|
||||
torch.randn(32, dtype=torch.float32, device="cuda"),
|
||||
torch.randn(32, dtype=torch.float32, device="cuda"),
|
||||
32,
|
||||
]
|
||||
assert len(kernel_add.cache) == 0
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache) == 1
|
||||
kernel_add.warmup(*args, grid=(1,))
|
||||
assert len(kernel_add.cache) == 1
|
||||
kernel_add.warmup(*args, grid=(1,))
|
||||
assert len(kernel_add.cache) == 1
|
||||
|
Reference in New Issue
Block a user