[TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)
AtomicRMWOp supports scalar

Co-authored-by: dongdongl <dongdongl@nvidia.com>
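
At the language level, the change means tl.atomic_add / tl.atomic_max / tl.atomic_min can now take a scalar (single-element) destination and value, which is what the re-enabled test_atomic_rmw below exercises through tl.atomic_{op}(Z, x). A minimal standalone sketch of that pattern, not part of the commit (assumes a CUDA device and stock Triton/PyTorch imports; the kernel name is illustrative):

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_atomic_add_kernel(X, Z):
    # each program instance loads one element and atomically adds it into the single-element Z
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    tl.atomic_add(Z, x)  # scalar destination + scalar value -> the scalar AtomicRMWOp path


x = torch.arange(4, dtype=torch.float32, device="cuda")
z = torch.zeros(1, dtype=torch.float32, device="cuda")
scalar_atomic_add_kernel[(4,)](x, z)
assert z.item() == x.sum().item()  # 0+1+2+3 accumulated atomically across 4 programs
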
@@ -595,100 +595,80 @@ def test_tuples():
     assert c_tri == c_ref


-# # ---------------
-# # test atomics
-# # ---------------
-# @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
-#     [
-#         ('add', 'float16', mode),
-#         ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
-#         ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
-#         ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
-#     ]
-#     for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
-# def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
-#     n_programs = 5
+# ---------------
+# test atomics
+# ---------------
+@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
+    [
+        ('add', 'float16', mode),
+        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
+        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
+        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
+    n_programs = 5

-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z):
-#         pid = tl.program_id(0)
-#         x = tl.load(X + pid)
-#         old = GENERATE_TEST_HERE

-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
-#     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-#     max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-#     min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
-#     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

-#     # triton result
-#     rs = RandomState(17)
-#     x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
-#     if mode == 'all_neg':
-#         x = -np.abs(x)
-#     if mode == 'all_pos':
-#         x = np.abs(x)
-#     if mode == 'min_neg':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = -np.max(np.abs(x)) - 1
-#     if mode == 'max_pos':
-#         idx = rs.randint(n_programs, size=(1, )).item()
-#         x[idx] = np.max(np.abs(x)) + 1
-#     x_tri = to_triton(x, device=device)

-#     z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
-#     kernel[(n_programs, )](x_tri, z_tri)
-#     # torch result
-#     z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-#     # compare
-#     exact = op not in ['add']
-#     if exact:
-#         assert z_ref.item() == to_numpy(z_tri).item()
-#     else:
-#         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


-# @pytest.mark.parametrize("axis", [0, 1])
-# def test_tensor_atomic_rmw(axis, device="cuda"):
-#     shape0, shape1 = 8, 8
-#     # triton kernel

-#     @triton.jit
-#     def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
-#         off0 = tl.arange(0, SHAPE0)
-#         off1 = tl.arange(0, SHAPE1)
-#         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-#         z = tl.sum(x, axis=AXIS)
-#         tl.atomic_add(Z + off0, z)
-#     rs = RandomState(17)
-#     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-#     # reference result
-#     z_ref = np.sum(x, axis=axis)
-#     # triton result
-#     x_tri = to_triton(x, device=device)
-#     z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
-#     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
-#     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

-def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 2, 8
     # triton kernel
     @triton.jit
-    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(X, Z):
+        pid = tl.program_id(0)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE

+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

+    # triton result
+    rs = RandomState(17)
+    x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs)
+    if mode == 'all_neg':
+        x = -np.abs(x)
+    if mode == 'all_pos':
+        x = np.abs(x)
+    if mode == 'min_neg':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = -np.max(np.abs(x)) - 1
+    if mode == 'max_pos':
+        idx = rs.randint(n_programs, size=(1, )).item()
+        x[idx] = np.max(np.abs(x)) + 1
+    x_tri = to_triton(x, device=device)

+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    kernel[(n_programs, )](x_tri, z_tri)
+    # torch result
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == to_numpy(z_tri).item()
+    else:
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

+#TODO[dongdongl]:add more cases with size of tensor less than warp size
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel

+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
-    # reference
-    z_ref = z + x
+    # reference result
+    z_ref = np.sum(x, axis=axis)
     # triton result
-    x_tri = torch.from_numpy(x).to(device=device)
-    z_tri = torch.from_numpy(z).to(device=device)
-    kernel[(1,)](z_tri, x_tri, shape0, shape1)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

 # def test_atomic_cas():
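
For contrast, the re-enabled test_tensor_atomic_rmw in the hunk covers the tensor form of the same op: a whole block of atomic adds issued through a tensor of pointers, which (as the commit title suggests) was the already-supported path. A condensed standalone version of that pattern, again only a sketch and not code from the commit (assumes CUDA, float32, and an 8x8 input):

import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_atomic_kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
    off0 = tl.arange(0, SHAPE0)
    off1 = tl.arange(0, SHAPE1)
    x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
    z = tl.sum(x, axis=1)       # reduce each row to a single value
    tl.atomic_add(Z + off0, z)  # tensor of pointers -> tensor AtomicRMWOp

x = torch.randn(8, 8, dtype=torch.float32, device="cuda")
z = torch.zeros(8, dtype=torch.float32, device="cuda")
row_sum_atomic_kernel[(1,)](z, x, 8, 8)
torch.testing.assert_close(z, x.sum(dim=1))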