[TRITON-MLIR][FRONTEND]minor fix to run through atomic_cas test (#925)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -677,36 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
|||||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||||
|
|
||||||
# def test_atomic_cas():
|
def test_atomic_cas():
|
||||||
# # 1. make sure that atomic_cas changes the original value (Lock)
|
|
||||||
# @triton.jit
|
|
||||||
# def change_value(Lock):
|
|
||||||
# tl.atomic_cas(Lock, 0, 1)
|
|
||||||
|
|
||||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
|
||||||
# change_value[(1,)](Lock)
|
|
||||||
|
|
||||||
# assert (Lock[0] == 1)
|
|
||||||
|
|
||||||
# # 2. only one block enters the critical section
|
|
||||||
# @triton.jit
|
|
||||||
# def serialized_add(data, Lock):
|
|
||||||
# ptrs = data + tl.arange(0, 128)
|
|
||||||
# while tl.atomic_cas(Lock, 0, 1) == 1:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# tl.store(ptrs, tl.load(ptrs) + 1.0)
|
|
||||||
|
|
||||||
# # release lock
|
|
||||||
# tl.atomic_xchg(Lock, 0)
|
|
||||||
|
|
||||||
# Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
|
||||||
# data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
|
||||||
# ref = torch.full((128,), 64.0)
|
|
||||||
# serialized_add[(64,)](data, Lock)
|
|
||||||
# triton.testing.assert_almost_equal(data, ref)
|
|
||||||
|
|
||||||
def test_simple_atomic_cas():
|
|
||||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def change_value(Lock):
|
def change_value(Lock):
|
||||||
@@ -717,6 +688,25 @@ def test_simple_atomic_cas():
|
|||||||
|
|
||||||
assert (Lock[0] == 1)
|
assert (Lock[0] == 1)
|
||||||
|
|
||||||
|
# 2. only one block enters the critical section
|
||||||
|
@triton.jit
|
||||||
|
def serialized_add(data, Lock):
|
||||||
|
ptrs = data + tl.arange(0, 128)
|
||||||
|
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||||
|
pass
|
||||||
|
|
||||||
|
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||||
|
|
||||||
|
# release lock
|
||||||
|
tl.atomic_xchg(Lock, 0)
|
||||||
|
|
||||||
|
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||||
|
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||||
|
ref = torch.full((128,), 64.0)
|
||||||
|
serialized_add[(64,)](data, Lock)
|
||||||
|
triton.testing.assert_almost_equal(data, ref)
|
||||||
|
|
||||||
|
|
||||||
# # ---------------
|
# # ---------------
|
||||||
# # test cast
|
# # test cast
|
||||||
# # ---------------
|
# # ---------------
|
||||||
|
@@ -528,7 +528,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
[ty.to_ir(self.builder) for ty in ret_types])
|
[ty.to_ir(self.builder) for ty in ret_types])
|
||||||
loop_block.merge_block_before(after_block)
|
loop_block.merge_block_before(after_block)
|
||||||
self.builder.set_insertion_point_to_end(after_block)
|
self.builder.set_insertion_point_to_end(after_block)
|
||||||
if len(yields) > 0:
|
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
|
||||||
# update global uses in while_op
|
# update global uses in while_op
|
||||||
|
Reference in New Issue
Block a user