[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)

Author: Keren Zhou
Date: 2022-11-26 12:30:38 -08:00
Committed by: GitHub
Parent: f63be0e9b5
Commit: 35c9ec1103
7 changed files with 116 additions and 29 deletions
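Per the title, the fix adjusts the number of warps and the number of threads per warp when the matrix is smaller than the defaults assume. As a rough, hypothetical sketch of that clamping idea in Python (the actual change lives in the MLIR backend, and clamp_launch_config is an illustrative helper, not Triton API):

import math

def clamp_launch_config(num_elements, threads_per_warp=32, num_warps=4):
    # Cap threads per warp so a tiny tensor (e.g. 2x2 = 4 elements)
    # does not get a 32-lane warp with 28 idle lanes.
    threads_per_warp = min(threads_per_warp, max(1, num_elements))
    # Cap the warp count at the number of warps actually needed.
    needed_warps = math.ceil(num_elements / threads_per_warp)
    num_warps = min(num_warps, max(1, needed_warps))
    return threads_per_warp, num_warps

# A 2x2 tensor: 4 elements -> 4 threads per warp, 1 warp.
assert clamp_launch_config(2 * 2) == (4, 1)
# A 32x32 tensor: 1024 elements -> the defaults already fit.
assert clamp_launch_config(32 * 32) == (32, 4)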


@@ -648,10 +648,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     else:
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
-@pytest.mark.parametrize("axis", [0, 1])
-def test_tensor_atomic_rmw(axis, device="cuda"):
-    shape0, shape1 = 8, 8
+# TODO[dongdongl]: add more cases with size of tensor less than warp size
+@pytest.mark.parametrize("shape, axis",
+                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
+def test_tensor_atomic_rmw(shape, axis, device="cuda"):
+    shape0, shape1 = shape
     # triton kernel
     @triton.jit
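The new parametrization covers five shapes, including ones smaller than a 32-thread warp, each with both reduction axes; expanding the comprehension shows the ten collected cases:

params = [(shape, axis)
          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)]
          for axis in [0, 1]]
assert len(params) == 10           # 5 shapes x 2 axes
assert params[0] == ((2, 2), 0)    # smallest case: 4 elements, below warp size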
@@ -660,14 +661,19 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
         z = tl.sum(x, axis=AXIS)
-        tl.atomic_add(Z + off0, z)
+        if AXIS == 1:
+            tl.atomic_add(Z + off0, z)
+        else:
+            tl.atomic_add(Z + off1, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
     print(x)
     # reference result
-    z_ref = np.sum(x, axis=axis)
+    z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    z_shape = (shape0, ) if axis == 1 else (shape1, )
+    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
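The new z_shape line reflects how the reduced axis determines the output length: summing over axis 1 leaves one value per row (length shape0, stored via off0), while summing over axis 0 leaves one value per column (length shape1, stored via off1). The same relationship in plain NumPy:

import numpy as np

x = np.ones((8, 2), dtype=np.float32)      # (shape0, shape1)
assert np.sum(x, axis=1).shape == (8,)     # axis == 1 -> Z indexed by off0
assert np.sum(x, axis=0).shape == (2,)     # axis == 0 -> Z indexed by off1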