[FRONTEND] Made test_if/test_default pass (#823)
@@ -1311,24 +1311,24 @@ def test_vectorization(N):
 # # TODO: can't be local to test_default


-# @triton.jit
-# def _impl(value=10):
-#     return value
+@triton.jit
+def _impl(value=10):
+    return value


-# def test_default():
-#     value = 5
-#     ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
-#     ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
+def test_default():
+    value = 5
+    ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
+    ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')

-#     @triton.jit
-#     def _kernel(ret0, ret1, value):
-#         tl.store(ret0, _impl())
-#         tl.store(ret1, _impl(value))
+    @triton.jit
+    def _kernel(ret0, ret1, value):
+        tl.store(ret0, _impl())
+        tl.store(ret1, _impl(value))

-#     _kernel[(1,)](ret0, ret1, value)
-#     assert ret0.item() == 10
-#     assert ret1.item() == value
+    _kernel[(1,)](ret0, ret1, value)
+    assert ret0.item() == 10
+    assert ret1.item() == value

 # # ---------------
 # # test noop
@@ -1484,36 +1484,36 @@ def test_call():
 # # -------------


-# def test_if():
+def test_if():

-#     @triton.jit
-#     def kernel(Cond, XTrue, XFalse, Ret):
-#         pid = tl.program_id(0)
-#         cond = tl.load(Cond)
-#         if pid % 2:
-#             tl.store(Ret, tl.load(XTrue))
-#         else:
-#             tl.store(Ret, tl.load(XFalse))
+    @triton.jit
+    def kernel(Cond, XTrue, XFalse, Ret):
+        pid = tl.program_id(0)
+        cond = tl.load(Cond)
+        if pid % 2:
+            tl.store(Ret, tl.load(XTrue))
+        else:
+            tl.store(Ret, tl.load(XFalse))

-#     cond = torch.ones(1, dtype=torch.int32, device='cuda')
-#     x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
-#     x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
-#     ret = torch.empty(1, dtype=torch.float32, device='cuda')
-#     kernel[(1,)](cond, x_true, x_false, ret)
+    cond = torch.ones(1, dtype=torch.int32, device='cuda')
+    x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
+    x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
+    ret = torch.empty(1, dtype=torch.float32, device='cuda')
+    kernel[(1,)](cond, x_true, x_false, ret)


-# def test_num_warps_pow2():
-#     dst = torch.empty(128, device='cuda')
+def test_num_warps_pow2():
+    dst = torch.empty(128, device='cuda')

-#     @triton.jit
-#     def _kernel(dst):
-#         pass
+    @triton.jit
+    def _kernel(dst):
+        pass

-#     with pytest.raises(AssertionError, match='must be a power of 2'):
-#         _kernel[(1,)](dst=dst, num_warps=3)
-#     _kernel[(1,)](dst=dst, num_warps=1)
-#     _kernel[(1,)](dst=dst, num_warps=2)
-#     _kernel[(1,)](dst=dst, num_warps=4)
+    with pytest.raises(AssertionError, match='must be a power of 2'):
+        _kernel[(1,)](dst=dst, num_warps=3)
+    _kernel[(1,)](dst=dst, num_warps=1)
+    _kernel[(1,)](dst=dst, num_warps=2)
+    _kernel[(1,)](dst=dst, num_warps=4)

 # # -------------
 # # test extern
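The re-enabled test_default exercises default arguments on a @triton.jit helper called from device code. A minimal standalone sketch of the same behavior outside the pytest harness (assumes a CUDA device and the triton/torch APIs used in the diff; not part of this commit):

import torch
import triton
import triton.language as tl


@triton.jit
def _impl(value=10):
    # JIT helper with a default argument.
    return value


@triton.jit
def _kernel(ret0, ret1, value):
    tl.store(ret0, _impl())       # no argument -> default (10) is used
    tl.store(ret1, _impl(value))  # explicit argument overrides the default


ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
_kernel[(1,)](ret0, ret1, 5)
assert ret0.item() == 10 and ret1.item() == 5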