[IR] Added IR and Codegen support for atomic_rmw (#120)
committed by Philippe Tillet
parent 59b0ac672a
commit 0274429429
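The test changes below exercise the new RMW builtins from the Python side: tl.atomic_add (already present) plus tl.atomic_max and tl.atomic_min. For orientation, a minimal sketch of the usage pattern being tested — not part of the commit, and assuming the v1-era **meta kernel signature used in this test file:

import torch
import triton
import triton.language as tl

@triton.jit
def global_max(X, Z, **meta):
    # one program instance per element; each folds its element into
    # the scalar accumulator Z with an atomic read-modify-write
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    old = tl.atomic_max(Z, x)  # the test binds the returned pre-update value

x = torch.randn(37, device='cuda')
z = torch.full([], float('-inf'), device='cuda')  # neutral element for max
global_max[(37, )](x, z)
assert z.item() == x.max().item()  # max is exact: atomics only ever store inputs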
@@ -4,6 +4,7 @@ import triton.language as tl
 import copy
 import pytest
 import ast
+import itertools

 torch.manual_seed(0)

@@ -238,8 +239,13 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
-def test_atomic_add(dtype_x, device='cuda'):
+@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
+    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
+    ('max', 'int32', mode), ('max', 'float32', mode),\
+    ('min', 'int32', mode), ('min', 'float32', mode),\
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     dtype_x = cvt[dtype_x]
     n_programs = 37

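The old test swept only dtypes for atomic_add; the new parametrization crosses seven (op, dtype) pairs with four sign modes, 28 cases in total. A hypothetical snippet (not part of the diff) showing what the itertools.chain.from_iterable expression flattens to:

import itertools

pairs = [('add', 'int32'), ('add', 'float16'), ('add', 'float32'),
         ('max', 'int32'), ('max', 'float32'),
         ('min', 'int32'), ('min', 'float32')]
modes = ['all_neg', 'all_pos', 'min_neg', 'max_pos']
params = list(itertools.chain.from_iterable(
    [(op, dtype, mode) for (op, dtype) in pairs] for mode in modes))
assert len(params) == 28  # 7 (op, dtype) pairs x 4 modes
assert params[0] == ('add', 'int32', 'all_neg')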
||||
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
     @triton.jit
     def kernel(X, Z, **meta):
         pid = tl.program_id(0)
-        old = tl.atomic_add(X, pid)
-        tl.store(Z + pid, old)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE

+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
+    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
+    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
+
     # triton result
-    x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
-    z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
+    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    if mode == 'all_neg':
+        x_tri = -torch.abs(x_tri)
+    if mode == 'all_pos':
+        x_tri = torch.abs(x_tri)
+    if mode == 'min_neg':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+    if mode == 'max_pos':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+
+    z_tri = torch.empty([], dtype=dtype_x, device=device)
+    z_tri.fill_(neutral)
     kernel[(n_programs, )](x_tri, z_tri)
-    last_sum = torch.max(z_tri) + torch.argmax(z_tri)
-    last_sum = last_sum.to(dtype_x)
     # torch result
-    range = torch.arange(n_programs, dtype=torch.int32, device=device)
-    x_ref = torch.sum(range).to(dtype_x)
-    triton.testing.assert_allclose(x_ref, x_tri[0])
-    triton.testing.assert_allclose(x_ref, last_sum)
+    z_ref = torch_op(x_tri).to(dtype_x)
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == z_tri.item()
+    else:
+        triton.testing.assert_allclose(z_ref, z_tri)


 # ---------------
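Two details of the new test are worth spelling out. First, GENERATE_TEST_HERE is a placeholder that patch_kernel substitutes into the kernel's source before JIT compilation, so a single template covers all three ops. The helper is defined elsewhere in this test file, not in this diff; its shape is roughly the following sketch, assuming it only performs textual substitution on the JITed function's source:

import copy

def patch_kernel(template, to_replace):
    # clone the JITed function and rewrite its source text, so the
    # template can be specialized to tl.atomic_add / _max / _min
    kernel = copy.deepcopy(template)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel

Second, z_tri is seeded with the neutral (identity) element of each op — 0 for add, -inf or the integer minimum for max, +inf or the integer maximum for min — so the accumulator cannot bias the result. min and max are compared exactly, while add falls back to assert_allclose: min/max atomics only ever store one of the input values, whereas floating-point adds accumulate in a nondeterministic order across program instances.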