From 9def1bcebfa3396fd9895e6c81390198ab8f7385 Mon Sep 17 00:00:00 2001 From: donproc Date: Thu, 1 Dec 2022 21:43:26 +0800 Subject: [PATCH] [TRITON-MLIR][FRONTEND]minor fix to run through atomic_cas test (#925) Co-authored-by: dongdongl --- python/tests/test_core.py | 50 ++++++++++++++++----------------------- python/triton/compiler.py | 3 +-- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/python/tests/test_core.py b/python/tests/test_core.py index a32130b40..127250afd 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"): kernel[(1,)](z_tri, x_tri, axis, shape0, shape1) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) -# def test_atomic_cas(): -# # 1. make sure that atomic_cas changes the original value (Lock) -# @triton.jit -# def change_value(Lock): -# tl.atomic_cas(Lock, 0, 1) - -# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32) -# change_value[(1,)](Lock) - -# assert (Lock[0] == 1) - -# # 2. only one block enters the critical section -# @triton.jit -# def serialized_add(data, Lock): -# ptrs = data + tl.arange(0, 128) -# while tl.atomic_cas(Lock, 0, 1) == 1: -# pass - -# tl.store(ptrs, tl.load(ptrs) + 1.0) - -# # release lock -# tl.atomic_xchg(Lock, 0) - -# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32) -# data = torch.zeros((128,), device='cuda', dtype=torch.float32) -# ref = torch.full((128,), 64.0) -# serialized_add[(64,)](data, Lock) -# triton.testing.assert_almost_equal(data, ref) - -def test_simple_atomic_cas(): +def test_atomic_cas(): # 1. make sure that atomic_cas changes the original value (Lock) @triton.jit def change_value(Lock): @@ -717,6 +688,25 @@ def test_simple_atomic_cas(): assert (Lock[0] == 1) + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1,), device='cuda', dtype=torch.int32) + data = torch.zeros((128,), device='cuda', dtype=torch.float32) + ref = torch.full((128,), 64.0) + serialized_add[(64,)](data, Lock) + triton.testing.assert_almost_equal(data, ref) + + # # --------------- # # test cast # # --------------- diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 6f3b78e2a..1df83d4fd 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -528,8 +528,7 @@ class CodeGenerator(ast.NodeVisitor): [ty.to_ir(self.builder) for ty in ret_types]) loop_block.merge_block_before(after_block) self.builder.set_insertion_point_to_end(after_block) - if len(yields) > 0: - self.builder.create_yield_op([y.handle for y in yields]) + self.builder.create_yield_op([y.handle for y in yields]) # update global uses in while_op for i, name in enumerate(names):