[FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; PTX has no native fp16 division or remainder, so the conversion is required anyway (see the sketch below).
- Do not force the result of integer binary operations to be the LHS type. PyTorch used to have a bug that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that caused tests to fail when run with `-W error` on Python 3.8.
commit 5cdb948c05 (parent 4a8953efa3)
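As a rough illustration of the first bullet (illustrative only, not part of this diff): once fp16 `/` and `%` are promoted, the torch reference expression that the tests below build simply casts both operands to fp32 before dividing.

```python
import torch

# Illustrative only; mirrors the torch_expr the tests below construct for fp16 '/' and '%'.
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
y = torch.tensor([3.0, 5.0, 7.0], dtype=torch.float16)

# PTX has no native fp16 division/remainder, so Triton computes these in fp32;
# the reference result performs the same promotion explicitly.
z_ref = x.to(torch.float32) / y.to(torch.float32)
```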
@@ -33,21 +33,28 @@ ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
   return a_rank > b_rank ? a_ty : b_ty;
 }
 
-ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
+enum class DivOrMod { NO, YES };
+
+ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
   context &ctx = a_ty->get_context();
   // 1) if one operand is double, the other is implicitly
   //    converted to double
-  if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
+  if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
     return type::get_fp64_ty(ctx);
   // 2) if one operand is float, the other is implicitly
   //    converted to float
-  if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
+  if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
     return type::get_fp32_ty(ctx);
-  // 3 ) if one operand is half, the other is implicitly
-  //     converted to half
-  if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
-    return type::get_fp16_ty(ctx);
-  if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
+  // 3 ) if one operand is half, the other is implicitly converted to half
+  //     unless we're doing / or %, which do not exist natively in PTX for fp16.
+  if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
+    if (div_or_mod == DivOrMod::YES) {
+      return type::get_fp32_ty(ctx);
+    } else {
+      return type::get_fp16_ty(ctx);
+    }
+  }
+  if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
     throw_unreachable("augment_types");
   // 4 ) both operands are integer and undergo
   //     integer promotion
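For readers who don't live in the C++ semantic-analysis layer, a rough Python paraphrase of the promotion ladder above. The helper name and the use of torch dtypes are illustrative only, not part of the codebase:

```python
import torch

def computation_dtype(a: torch.dtype, b: torch.dtype, div_or_mod: bool = False) -> torch.dtype:
    """Hypothetical paraphrase of computation_type(); for illustration only."""
    if torch.float64 in (a, b):          # 1) double dominates
        return torch.float64
    if torch.float32 in (a, b):          # 2) then float
        return torch.float32
    if torch.float16 in (a, b):          # 3) half, unless / or %: PTX has no fp16 div/rem
        return torch.float32 if div_or_mod else torch.float16
    # 4) both operands are integers: keep the wider type, mirroring integer_promote above
    return a if torch.iinfo(a).bits >= torch.iinfo(b).bits else b

assert computation_dtype(torch.float16, torch.float16, div_or_mod=True) == torch.float32
assert computation_dtype(torch.int8, torch.int32) == torch.int32
```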
@@ -78,7 +85,7 @@ void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
 
 void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
                              bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
-                             bool arithmetic_check = true){
+                             bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
   // implicit broadcasting
   std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
   // implicit typecasting
@@ -86,8 +93,8 @@ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* buil
   ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
   check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
   check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
-  if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){
-    ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty);
+  if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
+    ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
     lhs = dispatch::cast(lhs, ret_sca_ty, builder);
     rhs = dispatch::cast(rhs, ret_sca_ty, builder);
   }
@@ -140,7 +147,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
 }
 
 ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
-  binary_op_type_checking(input, other, builder, false, false, false);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
   ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
   // float / int
@@ -168,7 +175,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
 }
 
 ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
-  binary_op_type_checking(input, other, builder, false, false, false);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
   ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
   if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
@@ -181,7 +188,7 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
 }
 
 ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
-  binary_op_type_checking(input, other, builder);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float % int
   if (scalar_ty->is_floating_point_ty())
@@ -193,51 +200,43 @@ ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builde
 }
 
 
-void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){
+void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
   binary_op_type_checking(input, other, builder, false, false, false);
   ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
   ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
   if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
     throw_incompatible_types(input_sca_ty, other_sca_ty);
-  // for some reason pytorch assigns the result of binary op to have the type of the lhs...
-  if(force_lhs_type){
-    if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
-      other = dispatch::cast(other, input_sca_ty, builder);
-  }
-  else{
-    if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
-      input = dispatch::cast(input, other_sca_ty, builder);
-    else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
-      other = dispatch::cast(other, input_sca_ty, builder);
-  }
+  if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
+    input = dispatch::cast(input, other_sca_ty, builder);
+  else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
+    other = dispatch::cast(other, input_sca_ty, builder);
 }
 
 ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_and(input, other);
 }
 
 ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_or(input, other);
 }
 
 
 ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_xor(input, other);
 }
 
 
 ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, false);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_lshr(input, other);
 }
 
 
 ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, false);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_shl(input, other);
 }
 
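With `force_lhs_type` gone, bitwise operands are promoted symmetrically: the narrower integer is cast up to the wider one regardless of which side it sits on. A minimal sketch of the resulting rule, with a hypothetical helper name:

```python
import torch

def bitwise_result_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype:
    # Symmetric widening: the result no longer inherits the LHS type.
    return a if torch.iinfo(a).bits >= torch.iinfo(b).bits else b

assert bitwise_result_dtype(torch.int8, torch.int32) == torch.int32
assert bitwise_result_dtype(torch.int32, torch.int8) == torch.int32  # operand order doesn't matter
```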
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     triton.testing.assert_almost_equal(z_ref, z_tri)
 
 
-def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
-    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
+    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
+    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
     if mode_x == 'nan': x[:] = float('nan')
     if mode_y == 'nan': y[:] = float('nan')
     # reference result
-    z_ref = eval(expr)
+    z_ref = eval(expr if torch_expr is None else torch_expr)
     # triton result
     z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
     kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
@@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
     triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
 
 
+def _fake_fmod(x, y):
+    """
+    Triton % (for both integers and floats) has the same semantics as torch
+    fmod, but torch fmod doesn't work on integers until torch 1.8.
+    `_fake_fmod` gives the same semantics but works on all versions of torch.
+    """
+    z = torch.remainder(x, y)
+    return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
+
+
+def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
+    # The result of x % y is ill-conditioned if x % y is much smaller than x.
+    # pytorch/CUDA has slightly different (probably better) rounding on
+    # remainders than stock LLVM. We currently don't expect to match it
+    # bit-for-bit.
+    return (dtype_x, dtype_y) in [
+        ('int32', 'float16'),
+        ('int32', 'float32'),
+        ('int64', 'float16'),
+        ('int64', 'float32'),
+        ('int64', 'float64'),
+    ]
+
+
 # ---------------
 # test binary ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-    for op in ['+', '-', '*', '/', '%'] \
-    for dtype_x in dtypes \
+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['+', '-', '*', '/', '%']
+    for dtype_x in dtypes
     for dtype_y in dtypes
 ])
-def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
-    _test_binary(dtype_x, dtype_y, expr, device=device)
+def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f' x {op} y'
+    if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
+        # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
+        torch_expr = '_fake_fmod(x, y)'
+    elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
+        # Triton promotes 16-bit floating-point / and % to 32-bit because there
+        # are no native div or FRem operations on float16. Since we have to
+        # convert anyway, we may as well take the accuracy bump.
+        torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
+    else:
+        torch_expr = None
+    if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
+        with pytest.raises(AssertionError, match='Arrays are not almost equal'):
+            _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+    else:
+        _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
 
 
 # ---------------
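A quick worked check of the fmod semantics the hunk above relies on: the helper reproduces C/LLVM fmod (the result takes the sign of the dividend), whereas `torch.remainder` follows the sign of the divisor. The snippet below re-states the helper from the diff so it runs standalone.

```python
import torch

def _fake_fmod(x, y):
    # Same definition as the helper added in the hunk above.
    z = torch.remainder(x, y)
    return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)

x = torch.tensor([-7,  7, -7], dtype=torch.int32)
y = torch.tensor([ 3, -3, -3], dtype=torch.int32)

print(torch.remainder(x, y))  # tensor([ 2, -2, -1])  -> sign of the divisor
print(_fake_fmod(x, y))       # tensor([-1,  1, -1])  -> sign of the dividend, like C fmod
```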
@@ -482,7 +482,8 @@ class CodeGenerator(ast.NodeVisitor):
         with warnings.catch_warnings():
             # The ast library added visit_Constant and deprecated some other
            # methods but we can't move to that without breaking Python 3.6 and 3.7.
-            warnings.simplefilter("ignore", DeprecationWarning)
+            warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
+            warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
             return super().visit(node)
 
     def generic_visit(self, node):
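The extra filter matters because, per the comments above, the deprecated ast visitor paths surface as PendingDeprecationWarning on Python 3.8 and DeprecationWarning on Python 3.9; suppressing only one of them lets `pytest -W error` turn the other into a failure. A minimal, simplified sketch of the scoped-suppression pattern (the wrapper function is illustrative, not from the codebase):

```python
import warnings

def visit_quietly(visit, node):
    # Scope the ignores so that running the test suite with -W error still
    # surfaces warnings everywhere else.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)         # raised on Python 3.9
        warnings.simplefilter("ignore", PendingDeprecationWarning)  # raised on Python 3.8
        return visit(node)
```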
@@ -905,7 +906,7 @@ class JITFunction:
             node = generator.last_node
             if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                 raise e
-            raise CompilationError(self.src, node, e)
+            raise CompilationError(self.src, node) from e
 
     # - when `.src` attribute is set, cache path needs
     #   to be reinitialized
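`raise CompilationError(self.src, node) from e` fixes the bad invocation by chaining the original exception as `__cause__` instead of passing it as a third constructor argument. A generic sketch of the pattern; the stand-in class below is hypothetical and only mirrors the `(src, node)` call shape:

```python
class CompilationError(Exception):
    # Hypothetical two-argument stand-in matching the (src, node) call shape above.
    def __init__(self, src, node):
        super().__init__(f"could not compile {node!r}")

def compile_node(src, node):
    try:
        raise TypeError("unsupported operand types")
    except TypeError as e:
        # 'from e' keeps the original error in the traceback as __cause__,
        # without requiring the constructor to accept it as an argument.
        raise CompilationError(src, node) from e
```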
@@ -89,14 +89,21 @@ def assert_allclose(x, y, tol=1e-2):
     assert allclose(x, y, tol)
 
 
-def random(shape, dtype, device):
-    torch.manual_seed(0)
+def random(shape, dtype, device, seed=0):
+    """
+    Override the seed in tests if you're calling this function twice and don't
+    want the same result for both calls.
+    """
+    torch.manual_seed(seed)
     if isinstance(shape, int):
         shape = (shape, )
     if dtype == torch.bool:
         return torch.randint(0, 2, shape, dtype=dtype, device=device)
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-        return torch.randint(1, 32, shape, dtype=dtype, device=device)
+        iinfo = torch.iinfo(dtype)
+        x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
+        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
+        return x
     if dtype in [torch.float16, torch.float32, torch.float64]:
         return torch.normal(0, 1, shape, dtype=dtype, device=device)
     raise RuntimeError(f'Unknown dtype {dtype}')
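Taken together, the `triton.testing.random` changes mean integer inputs now span the type's full range (with zeros bumped to 1 so division tests don't fault), and callers can pass distinct seeds so the two operands of a binary op differ even when their dtypes match. A minimal usage sketch, assuming a CUDA device is available:

```python
import torch
import triton

x = triton.testing.random(128, dtype=torch.int32, device='cuda', seed=17)
y = triton.testing.random(128, dtype=torch.int32, device='cuda', seed=144)

assert not torch.equal(x, y)  # different seeds -> LHS and RHS are not identical
assert bool((y != 0).all())   # zeros are replaced, so x / y and x % y are safe to test
```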