[FRONTEND] Made test_if/test_default pass (#823)
This commit is contained in:
@@ -1311,24 +1311,24 @@ def test_vectorization(N):
|
|||||||
# # TODO: can't be local to test_default
|
# # TODO: can't be local to test_default
|
||||||
|
|
||||||
|
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def _impl(value=10):
|
def _impl(value=10):
|
||||||
# return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
# def test_default():
|
def test_default():
|
||||||
# value = 5
|
value = 5
|
||||||
# ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||||
# ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||||
|
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def _kernel(ret0, ret1, value):
|
def _kernel(ret0, ret1, value):
|
||||||
# tl.store(ret0, _impl())
|
tl.store(ret0, _impl())
|
||||||
# tl.store(ret1, _impl(value))
|
tl.store(ret1, _impl(value))
|
||||||
|
|
||||||
# _kernel[(1,)](ret0, ret1, value)
|
_kernel[(1,)](ret0, ret1, value)
|
||||||
# assert ret0.item() == 10
|
assert ret0.item() == 10
|
||||||
# assert ret1.item() == value
|
assert ret1.item() == value
|
||||||
|
|
||||||
# # ---------------
|
# # ---------------
|
||||||
# # test noop
|
# # test noop
|
||||||
@@ -1484,36 +1484,36 @@ def test_call():
|
|||||||
# # -------------
|
# # -------------
|
||||||
|
|
||||||
|
|
||||||
# def test_if():
|
def test_if():
|
||||||
|
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def kernel(Cond, XTrue, XFalse, Ret):
|
def kernel(Cond, XTrue, XFalse, Ret):
|
||||||
# pid = tl.program_id(0)
|
pid = tl.program_id(0)
|
||||||
# cond = tl.load(Cond)
|
cond = tl.load(Cond)
|
||||||
# if pid % 2:
|
if pid % 2:
|
||||||
# tl.store(Ret, tl.load(XTrue))
|
tl.store(Ret, tl.load(XTrue))
|
||||||
# else:
|
else:
|
||||||
# tl.store(Ret, tl.load(XFalse))
|
tl.store(Ret, tl.load(XFalse))
|
||||||
|
|
||||||
# cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||||
# x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||||
# x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||||
# ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||||
# kernel[(1,)](cond, x_true, x_false, ret)
|
kernel[(1,)](cond, x_true, x_false, ret)
|
||||||
|
|
||||||
|
|
||||||
# def test_num_warps_pow2():
|
def test_num_warps_pow2():
|
||||||
# dst = torch.empty(128, device='cuda')
|
dst = torch.empty(128, device='cuda')
|
||||||
|
|
||||||
# @triton.jit
|
@triton.jit
|
||||||
# def _kernel(dst):
|
def _kernel(dst):
|
||||||
# pass
|
pass
|
||||||
|
|
||||||
# with pytest.raises(AssertionError, match='must be a power of 2'):
|
with pytest.raises(AssertionError, match='must be a power of 2'):
|
||||||
# _kernel[(1,)](dst=dst, num_warps=3)
|
_kernel[(1,)](dst=dst, num_warps=3)
|
||||||
# _kernel[(1,)](dst=dst, num_warps=1)
|
_kernel[(1,)](dst=dst, num_warps=1)
|
||||||
# _kernel[(1,)](dst=dst, num_warps=2)
|
_kernel[(1,)](dst=dst, num_warps=2)
|
||||||
# _kernel[(1,)](dst=dst, num_warps=4)
|
_kernel[(1,)](dst=dst, num_warps=4)
|
||||||
|
|
||||||
# # -------------
|
# # -------------
|
||||||
# # test extern
|
# # test extern
|
||||||
|
@@ -196,7 +196,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return tuple(ret_types)
|
return tuple(ret_types)
|
||||||
else:
|
else:
|
||||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||||
self.builder.ret([ret_value.handle])
|
self.builder.ret([ret.handle])
|
||||||
return ret.type
|
return ret.type
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
@@ -399,13 +399,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||||
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
if len(names) > 0:
|
||||||
|
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||||
if not node.orelse:
|
if not node.orelse:
|
||||||
else_block = if_op.get_else_block()
|
else_block = if_op.get_else_block()
|
||||||
else:
|
else:
|
||||||
else_block.merge_block_before(if_op.get_else_block())
|
else_block.merge_block_before(if_op.get_else_block())
|
||||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||||
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
if len(names) > 0:
|
||||||
|
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||||
else: # no else block
|
else: # no else block
|
||||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
then_block.merge_block_before(if_op.get_then_block())
|
then_block.merge_block_before(if_op.get_then_block())
|
||||||
@@ -526,7 +528,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
[ty.to_ir(self.builder) for ty in ret_types])
|
[ty.to_ir(self.builder) for ty in ret_types])
|
||||||
loop_block.merge_block_before(after_block)
|
loop_block.merge_block_before(after_block)
|
||||||
self.builder.set_insertion_point_to_end(after_block)
|
self.builder.set_insertion_point_to_end(after_block)
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
if len(yields) > 0:
|
||||||
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
|
||||||
# update global uses in while_op
|
# update global uses in while_op
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
|
Reference in New Issue
Block a user