2021-04-20 22:29:40 -04:00
|
|
|
import torch
|
|
|
|
import triton
|
2021-04-23 17:18:14 -04:00
|
|
|
import triton.language as tl
|
2021-04-20 22:29:40 -04:00
|
|
|
import copy
|
|
|
|
import pytest
|
|
|
|
import ast
|
|
|
|
|
|
|
|
# Fixed seed for torch's random number generator.
torch.manual_seed(0)

# Dtype names used to parametrize the tests below.
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = [*int_dtypes, *float_dtypes]

# Lookup table from dtype name to torch.dtype.
# Tests carry dtype *names* rather than torch.dtype objects because, per the
# original note, torch.dtype does not print properly (presumably in the
# generated pytest ids -- confirm).
cvt = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ['bool', *int_dtypes, *float_dtypes]
}
|
|
|
|
|
|
|
|
|
|
|
|
def patch_kernel(template, to_replace):
    """Return a patched deep copy of ``template``.

    Every key of ``to_replace`` found in the copy's ``src`` text is substituted
    by the corresponding value, in the mapping's iteration order. ``template``
    itself is left untouched.

    Args:
        template: object exposing a string ``src`` attribute.
        to_replace: mapping of placeholder text to replacement text.

    Returns:
        The deep copy with its ``src`` rewritten.
    """
    patched = copy.deepcopy(template)
    for placeholder, replacement in to_replace.items():
        patched.src = patched.src.replace(placeholder, replacement)
    return patched
|
|
|
|
|
|
|
|
|
|
|
|
# generic test functions
|
|
|
|
def _test_unary(dtype_x, expr, device='cuda'):
    """Compare a unary expression run in a Triton kernel with its torch result.

    ``expr`` is Python source that refers to ``x``. It is substituted for the
    kernel's placeholder via ``patch_kernel`` and is also evaluated directly on
    the torch input to produce the reference value.

    Args:
        dtype_x: key into ``cvt`` selecting the dtype of the input.
        expr: source text of the expression under test.
        device: device on which the tensors are created.
    """
    SIZE = 128

    # Kernel template. Its source text is patched below, so explanatory
    # comments are kept outside of the kernel body.
    @triton.jit
    def kernel(Z, X, **meta):
        offsets = tl.arange(0, meta['SIZE'])
        x = tl.load(X + offsets)
        result = GENERATE_TEST_HERE
        tl.store(Z + offsets, result)

    patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})

    # Input; it must be bound to the local name `x` because eval(expr) reads it.
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)

    # Reference value computed by torch.
    expected = eval(expr)

    # Value computed by the patched kernel, written into `actual`.
    actual = torch.empty_like(expected)
    patched[(1, )](actual, x, SIZE=SIZE, num_warps=4)

    triton.testing.assert_allclose(expected, actual)
|
|
|
|
|
|
|
|
|
|
|
|
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
    """Compare a binary expression run in a Triton kernel with its torch result.

    ``expr`` is Python source that refers to ``x`` and ``y``. It is substituted
    for the kernel's placeholder via ``patch_kernel`` and is also evaluated
    directly on the torch inputs to produce the reference value.

    Args:
        dtype_x: key into ``cvt`` selecting the dtype of the first operand.
        dtype_y: key into ``cvt`` selecting the dtype of the second operand.
        expr: source text of the expression under test.
        device: device on which the tensors are created.
    """
    SIZE = 128

    # Kernel template. Its source text is patched below, so explanatory
    # comments are kept outside of the kernel body.
    @triton.jit
    def kernel(Z, X, Y, **meta):
        offsets = tl.arange(0, meta['SIZE'])
        x = tl.load(X + offsets)
        y = tl.load(Y + offsets)
        result = GENERATE_TEST_HERE
        tl.store(Z + offsets, result)

    patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})

    # Inputs; they must be bound to the local names `x` and `y` because
    # eval(expr) reads them.
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)

    # Reference value computed by torch.
    expected = eval(expr)

    # Value computed by the patched kernel; the output buffer takes the dtype
    # of the reference result.
    actual = torch.empty(SIZE, dtype=expected.dtype, device=device)
    patched[(1, )](actual, x, y, SIZE=SIZE, num_warps=4)

    triton.testing.assert_allclose(expected, actual)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test binary ops
|
|
|
|
# ---------------
|
|
|
|
@pytest.mark.parametrize(
    "dtype_x, dtype_y, expr",
    [
        (lhs_dtype, rhs_dtype, f' x {op} y')
        for op in ['+', '-', '*', '/', '%']
        for lhs_dtype in dtypes
        for rhs_dtype in dtypes
    ],
)
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
    """Run each arithmetic operator over every pair of dtypes in ``dtypes``."""
    _test_binary(dtype_x, dtype_y, expr, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test bitwise ops
|
|
|
|
# ---------------
|
|
|
|
@pytest.mark.parametrize(
    "dtype_x, dtype_y, expr",
    [
        (lhs_dtype, rhs_dtype, f' x {op} y')
        for op in ['&', '|', '^']
        for lhs_dtype in dtypes
        for rhs_dtype in dtypes
    ],
)
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
    """Run each bitwise operator over every pair of dtypes in ``dtypes``.

    When either dtype name contains 'float', the check is expected to raise
    RuntimeError; otherwise it must pass.
    """
    if 'float' not in dtype_x + dtype_y:
        _test_binary(dtype_x, dtype_y, expr, device=device)
        return
    with pytest.raises(RuntimeError):
        _test_binary(dtype_x, dtype_y, expr, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test compare ops
|
|
|
|
# ---------------
|
|
|
|
@pytest.mark.parametrize(
    "dtype_x, dtype_y, expr",
    [
        (lhs_dtype, rhs_dtype, f' x {op} y')
        for op in ['==', '!=', '>', '<', '>=', '<=']
        for lhs_dtype in dtypes
        for rhs_dtype in dtypes
    ],
)
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
    """Run each comparison operator over every pair of dtypes in ``dtypes``."""
    _test_binary(dtype_x, dtype_y, expr, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test unary ops
|
|
|
|
# ---------------
|
|
|
|
@pytest.mark.parametrize(
    "dtype_x, expr",
    [(float_dtype, ' -x') for float_dtype in float_dtypes]
    + [(int_dtype, ' ~x') for int_dtype in int_dtypes],
)
def test_unary_op(dtype_x, expr, device='cuda'):
    """Run negation on the float dtypes and bitwise-not on the int dtypes."""
    _test_unary(dtype_x, expr, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
# ----------------
|
|
|
|
# test indexing
|
|
|
|
# ----------------
|
|
|
|
|
|
|
|
|
|
|
|
def make_ptr_str(name, shape):
    """Build the source text of a pointer expression covering a block.

    One term is emitted per dimension of ``shape``, starting from the last
    dimension. The term for dimension ``i`` is
    ``tl.arange(0, shape[i])[<idx>]*<stride>``, where ``<idx>`` has ``:`` at
    position ``i`` and ``None`` at every other position, and ``<stride>`` is
    the product of the extents of all dimensions after ``i``.

    Args:
        name: text of the base pointer, e.g. ``'X'``.
        shape: sequence of integer extents, one per dimension.

    Returns:
        ``name`` followed by the terms joined with ``' + '``. For an empty
        ``shape`` there are no terms, and ``name`` is returned on its own
        rather than with a dangling ``' + '``.
    """
    rank = len(shape)
    if rank == 0:
        # No offsets to add: the bare pointer is the whole expression.
        return name
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join(':' if ii == i else 'None' for ii in range(rank))
        offsets.append(f'tl.arange(0, {shape[i]})[{idx}]*{stride}')
        # Extents seen so far multiply up to the stride of the next dimension.
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "expr",
    [
        f'x[{subscript}]'
        for subscript in ['None, :', ':, None', 'None, :, :', ':, :, None']
    ],
)
def test_index1d(expr, device='cuda'):
    """Check ``None``-indexing of ``x`` in a Triton kernel against torch.

    The rank of the input is the number of ``:`` in ``expr`` and the rank of
    the output is the number of comma-separated subscript entries; every
    dimension has extent 32.
    """
    dtype = torch.int32
    rank_x = expr.count(':')
    rank_z = expr.count(',') + 1
    shape_x = [32] * rank_x
    shape_z = [32] * rank_z

    # Kernel template. Its source text is patched below, so explanatory
    # comments are kept outside of the kernel body.
    # NOTE(review): `m` and `n` are not referenced by any of the expressions
    # substituted in this test -- confirm whether they can be dropped.
    @triton.jit
    def kernel(Z, X, **meta):
        SIZE = meta['SIZE']
        m = tl.arange(0, SIZE)
        n = tl.arange(0, SIZE)
        x = tl.load(X_PTR_EXPR)
        z = GENERATE_TEST_HERE
        tl.store(Z_PTR_EXPR, z)

    patched = patch_kernel(kernel, {
        'X_PTR_EXPR': make_ptr_str('X', shape_x),
        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
        'GENERATE_TEST_HERE': expr,
    })

    # Reference value computed by torch. The input must be bound to the local
    # name `x` because eval(expr) reads it; adding zeros of the output shape
    # broadcasts the indexed value to that shape.
    x = triton.testing.random(shape_x, dtype=dtype, device=device)
    zeros = torch.zeros(shape_z, dtype=dtype, device=device)
    expected = eval(expr) + zeros

    # Value computed by the patched kernel, written into `actual`.
    actual = torch.empty_like(expected)
    patched[(1, )](actual, x, num_warps=1, SIZE=shape_x[0])

    triton.testing.assert_allclose(expected, actual)
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test load
|
|
|
|
# ---------------
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test store
|
|
|
|
# ---------------
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test if
|
|
|
|
# ---------------
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test for
|
|
|
|
# ---------------
|
|
|
|
|
|
|
|
# ---------------
|
|
|
|
# test while
|
|
|
|
# ---------------
|