[FRONTEND] signed-integer math fixes and testing (#395)

- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway.
- Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
This commit is contained in:
Madeleine Thompson
2021-12-21 09:46:05 -08:00
committed by GitHub
parent 4a8953efa3
commit 5cdb948c05
4 changed files with 93 additions and 47 deletions

View File

@@ -33,21 +33,28 @@ ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
return a_rank > b_rank ? a_ty : b_ty; return a_rank > b_rank ? a_ty : b_ty;
} }
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){ enum class DivOrMod { NO, YES };
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
context &ctx = a_ty->get_context(); context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly // 1) if one operand is double, the other is implicitly
// converted to double // converted to double
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
return type::get_fp64_ty(ctx); return type::get_fp64_ty(ctx);
// 2) if one operand is float, the other is implicitly // 2) if one operand is float, the other is implicitly
// converted to float // converted to float
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
return type::get_fp32_ty(ctx); return type::get_fp32_ty(ctx);
// 3 ) if one operand is half, the other is implicitly // 3 ) if one operand is half, the other is implicitly converted to half
// converted to half // unless we're doing / or %, which do not exist natively in PTX for fp16.
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
if (div_or_mod == DivOrMod::YES) {
return type::get_fp32_ty(ctx);
} else {
return type::get_fp16_ty(ctx); return type::get_fp16_ty(ctx);
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) }
}
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("augment_types"); throw_unreachable("augment_types");
// 4 ) both operands are integer and undergo // 4 ) both operands are integer and undergo
// integer promotion // integer promotion
@@ -78,7 +85,7 @@ void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true){ bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
// implicit broadcasting // implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting // implicit typecasting
@@ -86,8 +93,8 @@ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* buil
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){ if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty); ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
lhs = dispatch::cast(lhs, ret_sca_ty, builder); lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder); rhs = dispatch::cast(rhs, ret_sca_ty, builder);
} }
@@ -140,7 +147,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
} }
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false); binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float / int // float / int
@@ -168,7 +175,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
} }
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
binary_op_type_checking(input, other, builder, false, false, false); binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
@@ -181,7 +188,7 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
} }
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder); binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty(); ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float % int // float % int
if (scalar_ty->is_floating_point_ty()) if (scalar_ty->is_floating_point_ty())
@@ -193,51 +200,43 @@ ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builde
} }
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false); binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty); throw_incompatible_types(input_sca_ty, other_sca_ty);
// for some reason pytorch assigns the result of binary op to have the type of the lhs...
if(force_lhs_type){
if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
}
else{
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
input = dispatch::cast(input, other_sca_ty, builder); input = dispatch::cast(input, other_sca_ty, builder);
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder); other = dispatch::cast(other, input_sca_ty, builder);
}
} }
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true); bitwise_op_type_checking(input, other, builder);
return builder->create_and(input, other); return builder->create_and(input, other);
} }
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true); bitwise_op_type_checking(input, other, builder);
return builder->create_or(input, other); return builder->create_or(input, other);
} }
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true); bitwise_op_type_checking(input, other, builder);
return builder->create_xor(input, other); return builder->create_xor(input, other);
} }
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false); bitwise_op_type_checking(input, other, builder);
return builder->create_lshr(input, other); return builder->create_lshr(input, other);
} }
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false); bitwise_op_type_checking(input, other, builder);
return builder->create_shl(input, other); return builder->create_shl(input, other);
} }

View File

@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
triton.testing.assert_almost_equal(z_ref, z_tri) triton.testing.assert_almost_equal(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'): def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
SIZE = 128 SIZE = 128
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
@@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs # inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device) y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
if mode_x == 'nan': x[:] = float('nan') if mode_x == 'nan': x[:] = float('nan')
if mode_y == 'nan': y[:] = float('nan') if mode_y == 'nan': y[:] = float('nan')
# reference result # reference result
z_ref = eval(expr) z_ref = eval(expr if torch_expr is None else torch_expr)
# triton result # triton result
z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
@@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr) triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
def _fake_fmod(x, y):
"""
Triton % (for both integers and floats) has the same semantics as torch
fmod, but torch fmod doesn't work on integers until torch 1.8.
`_fake_fmod` gives the same semantics but works on all versions of torch.
"""
z = torch.remainder(x, y)
return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# The result of x % y is ill-conditioned if x % y is much smaller than x.
# pytorch/CUDA has slightly different (probably better) rounding on
# remainders than stock LLVM. We currently don't expect to match it
# bit-for-bit.
return (dtype_x, dtype_y) in [
('int32', 'float16'),
('int32', 'float32'),
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
]
# --------------- # ---------------
# test binary ops # test binary ops
# --------------- # ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ @pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, f' x {op} y') \ (dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%'] \ for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes \ for dtype_x in dtypes
for dtype_y in dtypes for dtype_y in dtypes
]) ])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'): def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device) expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
# LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
torch_expr = '_fake_fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
# Triton promotes 16-bit floating-point / and % to 32-bit because there
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
else:
torch_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Arrays are not almost equal'):
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
else:
_test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
# --------------- # ---------------

View File

@@ -482,7 +482,8 @@ class CodeGenerator(ast.NodeVisitor):
with warnings.catch_warnings(): with warnings.catch_warnings():
# The ast library added visit_Constant and deprecated some other # The ast library added visit_Constant and deprecated some other
# methods but we can't move to that without breaking Python 3.6 and 3.7. # methods but we can't move to that without breaking Python 3.6 and 3.7.
warnings.simplefilter("ignore", DeprecationWarning) warnings.simplefilter("ignore", DeprecationWarning) # python 3.9
warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
return super().visit(node) return super().visit(node)
def generic_visit(self, node): def generic_visit(self, node):
@@ -905,7 +906,7 @@ class JITFunction:
node = generator.last_node node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)): if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e raise e
raise CompilationError(self.src, node, e) raise CompilationError(self.src, node) from e
# - when `.src` attribute is set, cache path needs # - when `.src` attribute is set, cache path needs
# to be reinitialized # to be reinitialized

View File

@@ -89,14 +89,21 @@ def assert_allclose(x, y, tol=1e-2):
assert allclose(x, y, tol) assert allclose(x, y, tol)
def random(shape, dtype, device): def random(shape, dtype, device, seed=0):
torch.manual_seed(0) """
Override the seed in tests if you're calling this function twice and don't
want the same result for both calls.
"""
torch.manual_seed(seed)
if isinstance(shape, int): if isinstance(shape, int):
shape = (shape, ) shape = (shape, )
if dtype == torch.bool: if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device) return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.randint(1, 32, shape, dtype=dtype, device=device) iinfo = torch.iinfo(dtype)
x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
if dtype in [torch.float16, torch.float32, torch.float64]: if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.normal(0, 1, shape, dtype=dtype, device=device) return torch.normal(0, 1, shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}') raise RuntimeError(f'Unknown dtype {dtype}')