[FRONTEND] Made test_if/test_default pass (#823)

This commit is contained in:
Philippe Tillet
2022-10-30 15:32:55 -07:00
committed by GitHub
parent e61dc75942
commit cb1b87a688
2 changed files with 45 additions and 42 deletions

View File

@@ -1311,24 +1311,24 @@ def test_vectorization(N):
# # TODO: can't be local to test_default # # TODO: can't be local to test_default
# @triton.jit @triton.jit
# def _impl(value=10): def _impl(value=10):
# return value return value
# def test_default(): def test_default():
# value = 5 value = 5
# ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
# ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
# @triton.jit @triton.jit
# def _kernel(ret0, ret1, value): def _kernel(ret0, ret1, value):
# tl.store(ret0, _impl()) tl.store(ret0, _impl())
# tl.store(ret1, _impl(value)) tl.store(ret1, _impl(value))
# _kernel[(1,)](ret0, ret1, value) _kernel[(1,)](ret0, ret1, value)
# assert ret0.item() == 10 assert ret0.item() == 10
# assert ret1.item() == value assert ret1.item() == value
# # --------------- # # ---------------
# # test noop # # test noop
@@ -1484,36 +1484,36 @@ def test_call():
# # ------------- # # -------------
# def test_if(): def test_if():
# @triton.jit @triton.jit
# def kernel(Cond, XTrue, XFalse, Ret): def kernel(Cond, XTrue, XFalse, Ret):
# pid = tl.program_id(0) pid = tl.program_id(0)
# cond = tl.load(Cond) cond = tl.load(Cond)
# if pid % 2: if pid % 2:
# tl.store(Ret, tl.load(XTrue)) tl.store(Ret, tl.load(XTrue))
# else: else:
# tl.store(Ret, tl.load(XFalse)) tl.store(Ret, tl.load(XFalse))
# cond = torch.ones(1, dtype=torch.int32, device='cuda') cond = torch.ones(1, dtype=torch.int32, device='cuda')
# x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
# x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
# ret = torch.empty(1, dtype=torch.float32, device='cuda') ret = torch.empty(1, dtype=torch.float32, device='cuda')
# kernel[(1,)](cond, x_true, x_false, ret) kernel[(1,)](cond, x_true, x_false, ret)
# def test_num_warps_pow2(): def test_num_warps_pow2():
# dst = torch.empty(128, device='cuda') dst = torch.empty(128, device='cuda')
# @triton.jit @triton.jit
# def _kernel(dst): def _kernel(dst):
# pass pass
# with pytest.raises(AssertionError, match='must be a power of 2'): with pytest.raises(AssertionError, match='must be a power of 2'):
# _kernel[(1,)](dst=dst, num_warps=3) _kernel[(1,)](dst=dst, num_warps=3)
# _kernel[(1,)](dst=dst, num_warps=1) _kernel[(1,)](dst=dst, num_warps=1)
# _kernel[(1,)](dst=dst, num_warps=2) _kernel[(1,)](dst=dst, num_warps=2)
# _kernel[(1,)](dst=dst, num_warps=4) _kernel[(1,)](dst=dst, num_warps=4)
# # ------------- # # -------------
# # test extern # # test extern

View File

@@ -196,7 +196,7 @@ class CodeGenerator(ast.NodeVisitor):
return tuple(ret_types) return tuple(ret_types)
else: else:
ret = triton.language.core._to_tensor(ret_value, self.builder) ret = triton.language.core._to_tensor(ret_value, self.builder)
self.builder.ret([ret_value.handle]) self.builder.ret([ret.handle])
return ret.type return ret.type
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
@@ -399,12 +399,14 @@ class CodeGenerator(ast.NodeVisitor):
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_block.merge_block_before(if_op.get_then_block()) then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block()) self.builder.set_insertion_point_to_end(if_op.get_then_block())
if len(names) > 0:
self.builder.create_yield_op([then_defs[n].handle for n in names]) self.builder.create_yield_op([then_defs[n].handle for n in names])
if not node.orelse: if not node.orelse:
else_block = if_op.get_else_block() else_block = if_op.get_else_block()
else: else:
else_block.merge_block_before(if_op.get_else_block()) else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block()) self.builder.set_insertion_point_to_end(if_op.get_else_block())
if len(names) > 0:
self.builder.create_yield_op([else_defs[n].handle for n in names]) self.builder.create_yield_op([else_defs[n].handle for n in names])
else: # no else block else: # no else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
@@ -526,6 +528,7 @@ class CodeGenerator(ast.NodeVisitor):
[ty.to_ir(self.builder) for ty in ret_types]) [ty.to_ir(self.builder) for ty in ret_types])
loop_block.merge_block_before(after_block) loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block) self.builder.set_insertion_point_to_end(after_block)
if len(yields) > 0:
self.builder.create_yield_op([y.handle for y in yields]) self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op # update global uses in while_op