From cb1b87a688b3c4f02eca89f559d97746f4f00b30 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 30 Oct 2022 15:32:55 -0700 Subject: [PATCH] [FRONTEND] Made test_if/test_default pass (#823) --- python/tests/test_core.py | 76 +++++++++++++++++++-------------------- python/triton/compiler.py | 11 +++--- 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 4a9b9ed10..1a6a15c82 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -1311,24 +1311,24 @@ def test_vectorization(N): # # TODO: can't be local to test_default -# @triton.jit -# def _impl(value=10): -# return value +@triton.jit +def _impl(value=10): + return value -# def test_default(): -# value = 5 -# ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') -# ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') +def test_default(): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') + ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') -# @triton.jit -# def _kernel(ret0, ret1, value): -# tl.store(ret0, _impl()) -# tl.store(ret1, _impl(value)) + @triton.jit + def _kernel(ret0, ret1, value): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) -# _kernel[(1,)](ret0, ret1, value) -# assert ret0.item() == 10 -# assert ret1.item() == value + _kernel[(1,)](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value # # --------------- # # test noop @@ -1484,36 +1484,36 @@ def test_call(): # # ------------- -# def test_if(): +def test_if(): -# @triton.jit -# def kernel(Cond, XTrue, XFalse, Ret): -# pid = tl.program_id(0) -# cond = tl.load(Cond) -# if pid % 2: -# tl.store(Ret, tl.load(XTrue)) -# else: -# tl.store(Ret, tl.load(XFalse)) + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret): + pid = tl.program_id(0) + cond = tl.load(Cond) + if pid % 2: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) -# cond = torch.ones(1, dtype=torch.int32, device='cuda') -# x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') -# x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') -# ret = torch.empty(1, dtype=torch.float32, device='cuda') -# kernel[(1,)](cond, x_true, x_false, ret) + cond = torch.ones(1, dtype=torch.int32, device='cuda') + x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') + x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') + ret = torch.empty(1, dtype=torch.float32, device='cuda') + kernel[(1,)](cond, x_true, x_false, ret) -# def test_num_warps_pow2(): -# dst = torch.empty(128, device='cuda') +def test_num_warps_pow2(): + dst = torch.empty(128, device='cuda') -# @triton.jit -# def _kernel(dst): -# pass + @triton.jit + def _kernel(dst): + pass -# with pytest.raises(AssertionError, match='must be a power of 2'): -# _kernel[(1,)](dst=dst, num_warps=3) -# _kernel[(1,)](dst=dst, num_warps=1) -# _kernel[(1,)](dst=dst, num_warps=2) -# _kernel[(1,)](dst=dst, num_warps=4) + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1,)](dst=dst, num_warps=3) + _kernel[(1,)](dst=dst, num_warps=1) + _kernel[(1,)](dst=dst, num_warps=2) + _kernel[(1,)](dst=dst, num_warps=4) # # ------------- # # test extern diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 8c93360ec..2558cd322 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -196,7 +196,7 @@ class CodeGenerator(ast.NodeVisitor): return tuple(ret_types) else: ret = triton.language.core._to_tensor(ret_value, self.builder) - self.builder.ret([ret_value.handle]) + self.builder.ret([ret.handle]) return ret.type def visit_FunctionDef(self, node): @@ -399,13 +399,15 @@ class CodeGenerator(ast.NodeVisitor): if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) then_block.merge_block_before(if_op.get_then_block()) self.builder.set_insertion_point_to_end(if_op.get_then_block()) - self.builder.create_yield_op([then_defs[n].handle for n in names]) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) if not node.orelse: else_block = if_op.get_else_block() else: else_block.merge_block_before(if_op.get_else_block()) self.builder.set_insertion_point_to_end(if_op.get_else_block()) - self.builder.create_yield_op([else_defs[n].handle for n in names]) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) else: # no else block if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) then_block.merge_block_before(if_op.get_then_block()) @@ -526,7 +528,8 @@ class CodeGenerator(ast.NodeVisitor): [ty.to_ir(self.builder) for ty in ret_types]) loop_block.merge_block_before(after_block) self.builder.set_insertion_point_to_end(after_block) - self.builder.create_yield_op([y.handle for y in yields]) + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) # update global uses in while_op for i, name in enumerate(names):