[IR] Added IR and Codegen support for atomic_rmw (#120)
This commit is contained in:
committed by
Philippe Tillet
parent
59b0ac672a
commit
0274429429
@@ -137,6 +137,11 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
|
||||
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
|
||||
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
|
||||
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
|
||||
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
|
||||
// linear algebra
|
||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||
// indexing
|
||||
|
@@ -4,6 +4,7 @@ import triton.language as tl
|
||||
import copy
|
||||
import pytest
|
||||
import ast
|
||||
import itertools
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -238,8 +239,13 @@ def test_tuples():
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
|
||||
def test_atomic_add(dtype_x, device='cuda'):
|
||||
@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
|
||||
[('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
|
||||
('max', 'int32', mode), ('max', 'float32', mode),\
|
||||
('min', 'int32', mode), ('min', 'float32', mode),\
|
||||
]
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
dtype_x = cvt[dtype_x]
|
||||
n_programs = 37
|
||||
|
||||
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
pid = tl.program_id(0)
|
||||
old = tl.atomic_add(X, pid)
|
||||
tl.store(Z + pid, old)
|
||||
x = tl.load(X + pid)
|
||||
old = GENERATE_TEST_HERE
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
|
||||
torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
|
||||
max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
|
||||
min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
|
||||
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
|
||||
|
||||
# triton result
|
||||
x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
|
||||
z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
|
||||
x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
|
||||
if mode == 'all_neg':
|
||||
x_tri = -torch.abs(x_tri)
|
||||
if mode == 'all_pos':
|
||||
x_tri = torch.abs(x_tri)
|
||||
if mode == 'min_neg':
|
||||
idx = torch.randint(n_programs, size=(1, )).item()
|
||||
x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
|
||||
if mode == 'max_pos':
|
||||
idx = torch.randint(n_programs, size=(1, )).item()
|
||||
x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
|
||||
|
||||
z_tri = torch.empty([], dtype=dtype_x, device=device)
|
||||
z_tri.fill_(neutral)
|
||||
kernel[(n_programs, )](x_tri, z_tri)
|
||||
last_sum = torch.max(z_tri) + torch.argmax(z_tri)
|
||||
last_sum = last_sum.to(dtype_x)
|
||||
# torch result
|
||||
range = torch.arange(n_programs, dtype=torch.int32, device=device)
|
||||
x_ref = torch.sum(range).to(dtype_x)
|
||||
triton.testing.assert_allclose(x_ref, x_tri[0])
|
||||
triton.testing.assert_allclose(x_ref, last_sum)
|
||||
z_ref = torch_op(x_tri).to(dtype_x)
|
||||
# compare
|
||||
exact = op not in ['add']
|
||||
if exact:
|
||||
assert z_ref.item() == z_tri.item()
|
||||
else:
|
||||
triton.testing.assert_allclose(z_ref, z_tri)
|
||||
|
||||
|
||||
# ---------------
|
||||
|
@@ -89,6 +89,8 @@ float16 = dtype(ir.type.get_fp16)
|
||||
float32 = dtype(ir.type.get_fp32)
|
||||
float64 = dtype(ir.type.get_fp64)
|
||||
|
||||
pi32_t = pointer_dtype(int32)
|
||||
|
||||
|
||||
class block:
|
||||
@staticmethod
|
||||
@@ -464,6 +466,31 @@ def atomic_add(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_add(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_max(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_max(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_min(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_min(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_and(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_and(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_or(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_or(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_xor(pointer, val, mask=None, builder=None):
|
||||
return frontend.atomic_xor(pointer, val, mask, builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
@@ -76,7 +76,7 @@ def random(shape, dtype, device):
|
||||
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
return torch.randint(1, 32, shape, dtype=dtype, device=device)
|
||||
if dtype in [torch.float16, torch.float32, torch.float64]:
|
||||
return torch.randn(shape, dtype=dtype, device=device)
|
||||
return torch.normal(0, 10, shape, dtype=dtype, device=device)
|
||||
raise RuntimeError(f'Unknown dtype {dtype}')
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user