[FRONTEND] Made more tests pass (#805)

Philippe Tillet
2022-10-26 17:47:33 -07:00
committed by GitHub
parent bb7008651a
commit 3e6cc6d66c
9 changed files with 303 additions and 166 deletions


@@ -281,141 +281,142 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# @pytest.mark.parametrize("dtype_x, dtype_y",
#                          [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
#                          [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
#                          )
# def test_floordiv(dtype_x, dtype_y, device='cuda'):
#     # Triton has IEEE, not numpy/torch, semantics for %, and those carry
#     # through to //, so we have to use a nonstandard expression to get a
#     # reference result for //.
#     expr = 'x // y'
#     numpy_expr = '((x - np.fmod(x, y)) / y)'
#     _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y",
                         [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
                         [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
                         )
def test_floordiv(dtype_x, dtype_y, device='cuda'):
    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
    # through to //, so we have to use a nonstandard expression to get a
    # reference result for //.
    expr = 'x // y'
    numpy_expr = '((x - np.fmod(x, y)) / y)'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
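# An illustrative standalone check (not part of this commit): np.fmod keeps
# the sign of the dividend (truncation toward zero), unlike Python's flooring
# // and np.mod, so the nonstandard reference above reproduces a
# truncation-style // on mixed signs.
import numpy as np
assert -7 // 2 == -4                      # Python/NumPy floor division
assert (-7 - np.fmod(-7, 2)) / 2 == -3    # truncating division via the reference expression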
# # ---------------
# # test bitwise ops
# # ---------------
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [
#     (dtype_x, dtype_y, op)
#     for op in ['&', '|', '^']
#     for dtype_x in dtypes + dtypes_with_bfloat16
#     for dtype_y in dtypes + dtypes_with_bfloat16
# ])
# def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
#     expr = f'x {op} y'
#     if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
#         numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
#     elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
#         numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
#     else:
#         numpy_expr = None
#     if 'float' in dtype_x + dtype_y:
#         with pytest.raises(triton.CompilationError) as exc_info:
#             _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
#         # The CompilationError must have been caused by a C++ exception with this text.
#         assert re.match('invalid operands of type', str(exc_info.value.__cause__))
#     else:
#         _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['&', '|', '^']
    for dtype_x in dtypes + dtypes_with_bfloat16
    for dtype_y in dtypes + dtypes_with_bfloat16
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(triton.CompilationError) as exc_info:
            _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
        # The CompilationError must have been caused by a C++ exception with this text.
        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
    else:
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
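# An illustrative aside (not part of this commit) on the casts above: NumPy
# promotes mixed int/uint operands of equal width to a wider signed type,
# while the test pins the reference to the unsigned type of the wider
# operand, which is assumed here to match Triton's promotion rule.
import numpy as np
x, y = np.uint32(0xFFFFFFFF), np.int32(-1)
assert (x & y).dtype == np.int64                     # NumPy's default promotion
assert (x & y.astype(np.uint32)).dtype == np.uint32  # the test's pinned unsigned type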
# @pytest.mark.parametrize("dtype_x, dtype_y, op", [
#     (dtype_x, dtype_y, op)
#     for op in ['<<', '>>']
#     for dtype_x in int_dtypes + uint_dtypes
#     for dtype_y in int_dtypes + uint_dtypes
# ])
# def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
#     expr = f'x {op} y'
#     bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
#     dtype_z = f'uint{bw}'
#     numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
#     _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['<<', '>>']
    for dtype_x in int_dtypes + uint_dtypes
    for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
    dtype_z = f'uint{bw}'
    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
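# An illustrative aside (not part of this commit) on the uint{bw} cast above:
# NumPy's >> on signed integers is an arithmetic shift, so viewing both
# operands as the unsigned type of the wider bitwidth gives a logical-shift
# reference instead.
import numpy as np
assert np.int8(-1) >> 1 == -1                    # arithmetic shift keeps the sign bit
assert np.int8(-1).astype(np.uint8) >> 1 == 127  # logical shift on the unsigned view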
# # ---------------
# # test compare ops
# # ---------------
# ops = ['==', '!=', '>', '<', '>=', '<=']
# ---------------
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
# @pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
#                          # real
#                          [
#                              (dtype_x, dtype_y, op, 'real', 'real')
#                              for op in ops
#                              for dtype_x in dtypes
#                              for dtype_y in dtypes
#                          ] +
#                          # NaNs
#                          [('float32', 'float32', op, mode_x, mode_y)
#                           for op in ops
#                           for mode_x, mode_y in [('nan', 'real'),
#                                                  ('real', 'nan'),
#                                                  ('nan', 'nan')]
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
                         # real
                         [
                             (dtype_x, dtype_y, op, 'real', 'real')
                             for op in ops
                             for dtype_x in dtypes
                             for dtype_y in dtypes
                         ] +
                         # NaNs
                         [('float32', 'float32', op, mode_x, mode_y)
                          for op in ops
                          for mode_x, mode_y in [('nan', 'real'),
                                                 ('real', 'nan'),
                                                 ('nan', 'nan')]
#                          ])
# def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
#     expr = f'x {op} y'
#     if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
#         numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
#     elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
#         numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
#     else:
#         numpy_expr = None
#     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
                         ])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
    expr = f'x {op} y'
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
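# An illustrative aside (not part of this commit) on the 'nan' modes above:
# under IEEE-754 semantics every comparison involving NaN is False except !=,
# which is the behavior the NaN parametrizations exercise.
import numpy as np
nan = np.float32('nan')
assert not (nan == nan) and not (nan <= nan) and (nan != nan)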
# # ---------------
# # test where
# # ---------------
# @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
# def test_where(dtype):
#     select_ptrs = False
#     if dtype == "*int32":
#         dtype = "int64"
#         select_ptrs = True
#     check_type_supported(dtype)
# ---------------
# test where
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
def test_where(dtype):
    select_ptrs = False
    if dtype == "*int32":
        dtype = "int64"
        select_ptrs = True
    check_type_supported(dtype)
#     @triton.jit
#     def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
#                      BLOCK_SIZE: tl.constexpr,
#                      TEST_POINTERS: tl.constexpr):
#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#         mask = offsets < n_elements
#         decide = tl.load(cond_ptr + offsets, mask=mask)
#         if TEST_POINTERS:
#             a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
#             b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
#         else:
#             a = tl.load(a_ptr + offsets, mask=mask)
#             b = tl.load(b_ptr + offsets, mask=mask)
#         output = tl.where(decide, a, b)
#         tl.store(output_ptr + offsets, output, mask=mask)
    @triton.jit
    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr,
                     TEST_POINTERS: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        decide = tl.load(cond_ptr + offsets, mask=mask)
        if TEST_POINTERS:
            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
        else:
            a = tl.load(a_ptr + offsets, mask=mask)
            b = tl.load(b_ptr + offsets, mask=mask)
        output = tl.where(decide, a, b)
        tl.store(output_ptr + offsets, output, mask=mask)
#     SIZE = 1_000
#     rs = RandomState(17)
#     cond = numpy_random(SIZE, 'bool', rs)
#     x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
#     y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
#     z = np.where(cond, x, y)
    SIZE = 1_000
    rs = RandomState(17)
    cond = numpy_random(SIZE, 'bool', rs)
    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    z = np.where(cond, x, y)
#     cond_tri = to_triton(cond, device='cuda')
#     x_tri = to_triton(x, device='cuda', dst_type=dtype)
#     y_tri = to_triton(y, device='cuda', dst_type=dtype)
#     z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
    cond_tri = to_triton(cond, device='cuda')
    x_tri = to_triton(x, device='cuda', dst_type=dtype)
    y_tri = to_triton(y, device='cuda', dst_type=dtype)
    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
#     grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
#     where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
#     assert (z == to_numpy(z_tri)).all()
    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
    assert (z == to_numpy(z_tri)).all()
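# An illustrative aside (not part of this commit) on the launch above: the
# grid lambda starts one program per BLOCK_SIZE chunk, so SIZE=1_000 with
# BLOCK_SIZE=1024 launches a single program and the kernel's `mask` guards
# the out-of-range tail.
assert triton.cdiv(1_000, 1024) == 1  # ceil(1000 / 1024)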
# TODO: wrong result
# def test_where_broadcast():
#     @triton.jit
#     def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
#         xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
#         yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
#         mask = tl.load(cond_ptr + yoffsets)
#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
@@ -424,8 +425,8 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
#     @triton.jit
#     def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
#         xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
#         yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
#         mask = 0
#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
#         res = tl.where(mask, vals, 0.)
@@ -451,17 +452,19 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# # ---------------
# @pytest.mark.parametrize("dtype_x, expr", [
#     (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
# ] + [
#     (dtype_x, ' ~x') for dtype_x in int_dtypes
# ])
# def test_unary_op(dtype_x, expr, device='cuda'):
#     _test_unary(dtype_x, expr, device=device)
@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
] + [
    (dtype_x, ' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)
# # ----------------
# # test math ops
# # ----------------
# TODO: Math module
# # @pytest.mark.parametrize("expr", [
# #     'exp', 'log', 'cos', 'sin'
# # ])
@@ -479,17 +482,18 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# # ----------------
# def make_ptr_str(name, shape):
#     rank = len(shape)
#     offsets = []
#     stride = 1
#     for i in reversed(range(rank)):
#         idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
#         offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
#         stride *= shape[i]
#     return f"{name} + {' + '.join(offsets)}"
def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"
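# An illustrative aside (not part of this commit): for a 2D shape the helper
# expands into broadcasted per-axis offsets with row-major strides, e.g.
assert make_ptr_str('X', [4, 8]) == \
    "X + tl.arange(0, 8)[None, :]*1 + tl.arange(0, 4)[:, None]*8"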
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`
# @pytest.mark.parametrize("expr, dtype_str", [
#     (f'x[{s}]', d)
#     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']