[FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
This commit is contained in:
@@ -64,20 +64,21 @@ def reset_tmp_dir():
|
||||
|
||||
def test_reuse():
|
||||
counter = 0
|
||||
def inc_counter(key, binary):
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
for i in range(10):
|
||||
kernel[(1,)](x, 43, BLOCK=1024)
|
||||
kernel[(1,)](x, 1, BLOCK=1024)
|
||||
assert counter == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||
def test_specialize(mode):
|
||||
counter = 0
|
||||
def inc_counter(key, binary):
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
|
Reference in New Issue
Block a user