[TRITON-MLIR][FRONTEND] minor fix to run through atomic_cas test (#925)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-12-01 21:43:26 +08:00
committed by GitHub
parent 7d90a07d0b
commit 9def1bcebf
2 changed files with 21 additions and 32 deletions

View File

@@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
# def test_atomic_cas():
# # 1. make sure that atomic_cas changes the original value (Lock)
# @triton.jit
# def change_value(Lock):
# tl.atomic_cas(Lock, 0, 1)
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
# change_value[(1,)](Lock)
# assert (Lock[0] == 1)
# # 2. only one block enters the critical section
# @triton.jit
# def serialized_add(data, Lock):
# ptrs = data + tl.arange(0, 128)
# while tl.atomic_cas(Lock, 0, 1) == 1:
# pass
# tl.store(ptrs, tl.load(ptrs) + 1.0)
# # release lock
# tl.atomic_xchg(Lock, 0)
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
# data = torch.zeros((128,), device='cuda', dtype=torch.float32)
# ref = torch.full((128,), 64.0)
# serialized_add[(64,)](data, Lock)
# triton.testing.assert_almost_equal(data, ref)
def test_simple_atomic_cas():
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
@@ -717,6 +688,25 @@ def test_simple_atomic_cas():
assert (Lock[0] == 1)
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
# release lock
tl.atomic_xchg(Lock, 0)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
triton.testing.assert_almost_equal(data, ref)
# # ---------------
# # test cast
# # ---------------

View File

@@ -528,8 +528,7 @@ class CodeGenerator(ast.NodeVisitor):
[ty.to_ir(self.builder) for ty in ret_types])
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
if len(yields) > 0:
self.builder.create_yield_op([y.handle for y in yields])
self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op
for i, name in enumerate(names):