[TRITON-MLIR][BACKEND] AtomicRMWOp supports scalar (#903)

AtomicRMWOp supports scalar

Co-authored-by: dongdongl <dongdongl@nvidia.com>
Author: donproc
Date: 2022-11-23 15:59:09 +08:00
Committed by: GitHub
Parent: 2e33352419
Commit: 8925c2cd11
5 changed files with 163 additions and 125 deletions
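What the commit enables, in short: the backend's AtomicRMWOp lowering now handles the case where the pointer and value are scalars rather than tensors, which is what the re-enabled `test_atomic_rmw` cases below exercise via `tl.atomic_{add,max,min}`. A minimal sketch of that scalar path, mirroring the test in the diff (kernel name, sizes, and launch values here are illustrative, not part of the commit):

import torch
import triton
import triton.language as tl

@triton.jit
def scalar_atomic_add(X, Z):
    pid = tl.program_id(0)
    x = tl.load(X + pid)   # scalar load: one element per program instance
    tl.atomic_add(Z, x)    # scalar atomic read-modify-write on a single location

x = torch.randn(5, device='cuda')
z = torch.zeros(1, device='cuda')
scalar_atomic_add[(5,)](x, z)  # after the launch, z[0] approximates x.sum()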

@@ -595,100 +595,80 @@ def test_tuples():
     assert c_tri == c_ref
-# # ---------------
-# # test atomics
-# # ---------------
-# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
-#     [
-#         ('add', 'float16', mode),
-#         ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
-#         ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
-#         ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
-#     ]
-#     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
-#     n_programs = 5
+# ---------------
+# test atomics
+# ---------------
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    n_programs = 5
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z):
-#         pid = tl.program_id(0)
-#         x = tl.load(X + pid)
-#         old = GENERATE_TEST_HERE
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-#     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-#     max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-#     min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
-#     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
-#     # triton result
-#     rs = RandomState(17)
-#     x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
-#     if mode == 'all_neg':
-#         x = -np.abs(x)
-#     if mode == 'all_pos':
-#         x = np.abs(x)
-#     if mode == 'min_neg':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = -np.max(np.abs(x)) - 1
-#     if mode == 'max_pos':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = np.max(np.abs(x)) + 1
-#     x_tri = to_triton(x, device=device)
-#     z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
-#     kernel[(n_programs, )](x_tri, z_tri)
-#     # torch result
-#     z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-#     # compare
-#     exact = op not in ['add']
-#     if exact:
-#         assert z_ref.item() == to_numpy(z_tri).item()
-#     else:
-#         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
-# @pytest.mark.parametrize("axis", [0, 1])
-# def test_tensor_atomic_rmw(axis, device="cuda"):
-#     shape0, shape1 = 8, 8
-#     # triton kernel
-#     @triton.jit
-#     def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
-#         off0 = tl.arange(0, SHAPE0)
-#         off1 = tl.arange(0, SHAPE1)
-#         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-#         z = tl.sum(x, axis=AXIS)
-#         tl.atomic_add(Z + off0, z)
-#     rs = RandomState(17)
-#     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-#     # reference result
-#     z_ref = np.sum(x, axis=axis)
-#     # triton result
-#     x_tri = to_triton(x, device=device)
-#     z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
-#     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
-#     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
-def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 2, 8
     # triton kernel
     @triton.jit
-    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(X, Z):
+        pid = tl.program_id(0)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
+    # triton result
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
+    if mode == 'all_neg':
+        x = -np.abs(x)
+    if mode == 'all_pos':
+        x = np.abs(x)
+    if mode == 'min_neg':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
+    if mode == 'max_pos':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    kernel[(n_programs, )](x_tri, z_tri)
+    # torch result
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == to_numpy(z_tri).item()
+    else:
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+#TODO[dongdongl]:add more cases with size of tensor less than warp size
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel
+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    # reference
-    z_ref = z + x
+    # reference result
+    z_ref = np.sum(x, axis=axis)
     # triton result
-    x_tri = torch.from_numpy(x).to(device=device)
-    z_tri = torch.from_numpy(z).to(device=device)
-    kernel[(1,)](z_tri, x_tri, shape0, shape1)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 # def test_atomic_cas():