[IR] Added IR and Codegen support for atomic_rmw (#120)

Philippe Tillet
2021-05-25 18:31:48 -04:00
committed by Philippe Tillet
parent 59b0ac672a
commit 0274429429
17 changed files with 261 additions and 124 deletions

@@ -137,6 +137,11 @@ void init_triton_frontend(py::module &&m) {
   m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
   m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
   m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
+  m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
+  m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
+  m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
+  m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
+  m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
   // linear algebra
   m.def("dot", &ir::dispatch::dot, ret::reference);
   // indexing

@@ -4,6 +4,7 @@ import triton.language as tl
 import copy
 import pytest
 import ast
+import itertools

 torch.manual_seed(0)
@@ -238,8 +239,13 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
-def test_atomic_add(dtype_x, device='cuda'):
+@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
+    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
+     ('max', 'int32', mode), ('max', 'float32', mode), \
+     ('min', 'int32', mode), ('min', 'float32', mode), \
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     dtype_x = cvt[dtype_x]
     n_programs = 37
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
     @triton.jit
     def kernel(X, Z, **meta):
         pid = tl.program_id(0)
-        old = tl.atomic_add(X, pid)
-        tl.store(Z + pid, old)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
+    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
+    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
     # triton result
-    x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
-    z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
+    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    if mode == 'all_neg':
+        x_tri = -torch.abs(x_tri)
+    if mode == 'all_pos':
+        x_tri = torch.abs(x_tri)
+    if mode == 'min_neg':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+    if mode == 'max_pos':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+    z_tri = torch.empty([], dtype=dtype_x, device=device)
+    z_tri.fill_(neutral)
     kernel[(n_programs, )](x_tri, z_tri)
-    last_sum = torch.max(z_tri) + torch.argmax(z_tri)
-    last_sum = last_sum.to(dtype_x)
     # torch result
-    range = torch.arange(n_programs, dtype=torch.int32, device=device)
-    x_ref = torch.sum(range).to(dtype_x)
-    triton.testing.assert_allclose(x_ref, x_tri[0])
-    triton.testing.assert_allclose(x_ref, last_sum)
+    z_ref = torch_op(x_tri).to(dtype_x)
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == z_tri.item()
+    else:
+        triton.testing.assert_allclose(z_ref, z_tri)

 # ---------------
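The rewritten test parametrizes over (op, dtype_x, mode); patch_kernel (a helper defined earlier in this test module) textually substitutes the GENERATE_TEST_HERE placeholder with the op-specific tl.atomic_{op}(Z, x) call before compilation, and each result is reduced against the op's neutral element. As a side note not part of the commit, the itertools.chain.from_iterable expression simply flattens one list of cases per input mode; a minimal sketch of the expansion:

import itertools

def ops_per_mode(mode):
    # same (op, dtype) pairs as in the @pytest.mark.parametrize list above
    return [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode),
            ('max', 'int32', mode), ('max', 'float32', mode),
            ('min', 'int32', mode), ('min', 'float32', mode)]

cases = list(itertools.chain.from_iterable(
    ops_per_mode(mode) for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']))
assert len(cases) == 28                        # 7 (op, dtype) pairs x 4 modes
assert cases[0] == ('add', 'int32', 'all_neg')

The final comparison is exact for max and min, whose result must equal one of the inputs, and falls back to assert_allclose for add, since floating-point atomic additions may be applied in any order.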

@@ -89,6 +89,8 @@ float16 = dtype(ir.type.get_fp16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)
+pi32_t = pointer_dtype(int32)


 class block:
     @staticmethod
@@ -464,6 +466,31 @@ def atomic_add(pointer, val, mask=None, builder=None):
     return frontend.atomic_add(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_max(pointer, val, mask=None, builder=None):
+    return frontend.atomic_max(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_min(pointer, val, mask=None, builder=None):
+    return frontend.atomic_min(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_and(pointer, val, mask=None, builder=None):
+    return frontend.atomic_and(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_or(pointer, val, mask=None, builder=None):
+    return frontend.atomic_or(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_xor(pointer, val, mask=None, builder=None):
+    return frontend.atomic_xor(pointer, val, mask, builder)


 # -----------------------
 # Conditioning
 # -----------------------
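Together with the bindings added above, these builtins become callable from @triton.jit kernels. A minimal usage sketch in the style of the new test, not part of the commit (the kernel and tensor names are hypothetical; requires a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def max_kernel(X, Z, **meta):
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    # each program atomically folds its element into the scalar result at Z
    tl.atomic_max(Z, x)

n = 37
x = torch.randn(n, dtype=torch.float32, device='cuda')
z = torch.full([], float('-inf'), dtype=torch.float32, device='cuda')
max_kernel[(n, )](x, z)
assert z.item() == x.max().item()

The same pattern applies to tl.atomic_min, tl.atomic_and, tl.atomic_or, and tl.atomic_xor, with a suitable neutral initial value for z.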

@@ -76,7 +76,7 @@ def random(shape, dtype, device):
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
         return torch.randint(1, 32, shape, dtype=dtype, device=device)
     if dtype in [torch.float16, torch.float32, torch.float64]:
-        return torch.randn(shape, dtype=dtype, device=device)
+        return torch.normal(0, 10, shape, dtype=dtype, device=device)
     raise RuntimeError(f'Unknown dtype {dtype}')
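Finally, triton.testing.random now draws floating-point inputs from a zero-mean normal with standard deviation 10 instead of torch.randn, presumably to spread the test values out; the call signature is unchanged. A short usage sketch, not part of the commit (the device string is just an example):

import torch
import triton

x_int = triton.testing.random((37, ), dtype=torch.int32, device='cuda')    # uniform ints in [1, 32)
x_f32 = triton.testing.random((37, ), dtype=torch.float32, device='cuda')  # samples from normal(0, 10)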