[BACKEND] Fix some bugs (atomics, a segfault...) (#577)

This should fix #558 , #573 and #574
This commit is contained in:
Philippe Tillet
2022-07-06 20:03:04 -07:00
committed by GitHub
parent 22105bc33b
commit 4a399a7e40
4 changed files with 73 additions and 50 deletions

View File

@@ -532,6 +532,29 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
shape0, shape1 = 8, 8
# triton kernel
@triton.jit
def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
z = tl.sum(x, axis=AXIS)
tl.atomic_add(Z + off0, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# reference result
z_ref = np.sum(x, axis=axis)
# triton result
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit