[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)

This commit is contained in:
Philippe Tillet
2021-11-29 19:11:26 -08:00
committed by GitHub
parent 1296eb877b
commit c86ad9c9ab
4 changed files with 149 additions and 122 deletions

View File

@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
# test while
# ---------------
# ---------------
# test default
# ---------------
#TODO: can't be local to test_default
@triton.jit
def _impl(value = 10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
@triton.jit
def _kernel(ret0, ret1, value):
tl.store(ret0, _impl())
tl.store(ret1, _impl(value))
_kernel[(1,)](ret0, ret1, value)
assert ret0.item() == 10
assert ret1.item() == value
# ---------------
# test noop
#----------------