[CODEGEN] Fixed atomic_add issue (#112)
* [CODEGEN] Fixed atomic_add issue * [CODEGEN] Fixed liveness analysis bug for instructions that are not DCE'd but have no users (e.g., atomic_cas)
This commit is contained in:
committed by
Philippe Tillet
parent
325ee38581
commit
f81012a8cf
@@ -189,6 +189,34 @@ def test_index1d(expr, device='cuda'):
|
||||
triton.testing.assert_allclose(z_ref, z_tri)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
|
||||
def test_atomic_add(dtype_x, device='cuda'):
|
||||
dtype_x = cvt[dtype_x]
|
||||
n_programs = 37
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
pid = tl.program_id(0)
|
||||
old = tl.atomic_add(X, pid)
|
||||
tl.store(Z + pid, old)
|
||||
|
||||
# triton result
|
||||
x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
|
||||
z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
|
||||
kernel[(n_programs, )](x_tri, z_tri)
|
||||
last_sum = torch.max(z_tri) + torch.argmax(z_tri)
|
||||
last_sum = last_sum.to(dtype_x)
|
||||
# torch result
|
||||
range = torch.arange(n_programs, dtype=torch.int32, device=device)
|
||||
x_ref = torch.sum(range).to(dtype_x)
|
||||
triton.testing.assert_allclose(x_ref, x_tri[0])
|
||||
triton.testing.assert_allclose(x_ref, last_sum)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user