[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
This commit is contained in:
@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
|
||||
# test while
|
||||
# ---------------
|
||||
|
||||
# ---------------
|
||||
# test default
|
||||
# ---------------
|
||||
#TODO: can't be local to test_default
|
||||
@triton.jit
|
||||
def _impl(value = 10):
|
||||
return value
|
||||
|
||||
def test_default():
|
||||
value = 5
|
||||
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(ret0, ret1, value):
|
||||
tl.store(ret0, _impl())
|
||||
tl.store(ret1, _impl(value))
|
||||
|
||||
_kernel[(1,)](ret0, ret1, value)
|
||||
assert ret0.item() == 10
|
||||
assert ret1.item() == value
|
||||
|
||||
# ---------------
|
||||
# test noop
|
||||
#----------------
|
||||
|
Reference in New Issue
Block a user