[TRITON-MLIR][FRONTEND] Minor fix to run through the atomic_cas test (#925)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
# def test_atomic_cas():
|
||||
# # 1. make sure that atomic_cas changes the original value (Lock)
|
||||
# @triton.jit
|
||||
# def change_value(Lock):
|
||||
# tl.atomic_cas(Lock, 0, 1)
|
||||
|
||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
# change_value[(1,)](Lock)
|
||||
|
||||
# assert (Lock[0] == 1)
|
||||
|
||||
# # 2. only one block enters the critical section
|
||||
# @triton.jit
|
||||
# def serialized_add(data, Lock):
|
||||
# ptrs = data + tl.arange(0, 128)
|
||||
# while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
# pass
|
||||
|
||||
# tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
|
||||
# # release lock
|
||||
# tl.atomic_xchg(Lock, 0)
|
||||
|
||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
# data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
# ref = torch.full((128,), 64.0)
|
||||
# serialized_add[(64,)](data, Lock)
|
||||
# triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
def test_simple_atomic_cas():
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
@@ -717,6 +688,25 @@ def test_simple_atomic_cas():
|
||||
|
||||
assert (Lock[0] == 1)
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
|
||||
# # ---------------
|
||||
# # test cast
|
||||
# # ---------------
|
||||
|
@@ -528,8 +528,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
if len(yields) > 0:
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
|
Reference in New Issue
Block a user