diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 477f4dce0..c616b2fd4 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -33,21 +33,28 @@ ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
   return a_rank > b_rank ? a_ty : b_ty;
 }
 
-ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
+enum class DivOrMod { NO, YES };
+
+ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
   context &ctx = a_ty->get_context();
   // 1) if one operand is double, the other is implicitly
   //    converted to double
-  if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
+  if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
     return type::get_fp64_ty(ctx);
   // 2) if one operand is float, the other is implicitly
   //    converted to float
-  if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
+  if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
     return type::get_fp32_ty(ctx);
-  // 3 ) if one operand is half, the other is implicitly
-  //     converted to half
-  if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
-    return type::get_fp16_ty(ctx);
-  if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
+  // 3 ) if one operand is half, the other is implicitly converted to half
+  //     unless we're doing / or %, which do not exist natively in PTX for fp16.
+  if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
+    if (div_or_mod == DivOrMod::YES) {
+      return type::get_fp32_ty(ctx);
+    } else {
+      return type::get_fp16_ty(ctx);
+    }
+  }
+  if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
     throw_unreachable("augment_types");
   // 4 ) both operands are integer and undergo
   //    integer promotion
@@ -78,7 +85,7 @@ void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
 
 void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
                              bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
-                             bool arithmetic_check = true){
+                             bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
   // implicit broadcasting
   std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
   // implicit typecasting
@@ -86,8 +93,8 @@ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* buil
   ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
   check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
   check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
-  if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){
-    ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty);
+  if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
+    ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
     lhs = dispatch::cast(lhs, ret_sca_ty, builder);
     rhs = dispatch::cast(rhs, ret_sca_ty, builder);
   }
@@ -140,7 +147,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
 }
 
 ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
-  binary_op_type_checking(input, other, builder, false, false, false);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
   ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
   // float / int
@@ -168,7 +175,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
 }
 
 ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
-  binary_op_type_checking(input, other, builder, false, false, false);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
   ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
   if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
@@ -181,7 +188,7 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
 }
 
 ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
-  binary_op_type_checking(input, other, builder);
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float % int
   if (scalar_ty->is_floating_point_ty())
@@ -193,51 +200,43 @@ ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builde
 }
 
 
-void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){
+void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
   binary_op_type_checking(input, other, builder, false, false, false);
   ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
   ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
   if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
     throw_incompatible_types(input_sca_ty, other_sca_ty);
-  // for some reason pytorch assigns the result of binary op to have the type of the lhs...
-  if(force_lhs_type){
-    if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
-      other = dispatch::cast(other, input_sca_ty, builder);
-  }
-  else{
-    if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
-      input = dispatch::cast(input, other_sca_ty, builder);
-    else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
-      other = dispatch::cast(other, input_sca_ty, builder);
-  }
-
+  if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
+    input = dispatch::cast(input, other_sca_ty, builder);
+  else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
+    other = dispatch::cast(other, input_sca_ty, builder);
 }
 
 ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_and(input, other);
 }
 
 ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_or(input, other);
 }
 
 ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, true);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_xor(input, other);
 }
 
 ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, false);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_lshr(input, other);
 }
 
 ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
-  bitwise_op_type_checking(input, other, builder, false);
+  bitwise_op_type_checking(input, other, builder);
   return builder->create_shl(input, other);
 }
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index e85c399b3..785ca49ac 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     triton.testing.assert_almost_equal(z_ref, z_tri)
 
 
-def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
-    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
-    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
+    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17)
+    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144)
     if mode_x == 'nan': x[:] = float('nan')
     if mode_y == 'nan': y[:] = float('nan')
     # reference result
-    z_ref = eval(expr)
+    z_ref = eval(expr if torch_expr is None else torch_expr)
     # triton result
     z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
     kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
@@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
     triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr)
 
 
+def _fake_fmod(x, y):
+    """
+    Triton % (for both integers and floats) has the same semantics as torch
+    fmod, but torch fmod doesn't work on integers until torch 1.8.
+    `_fake_fmod` gives the same semantics but works on all versions of torch.
+    """
+    z = torch.remainder(x, y)
+    return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)
+
+
+def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
+    # The result of x % y is ill-conditioned if x % y is much smaller than x.
+    # pytorch/CUDA has slightly different (probably better) rounding on
+    # remainders than stock LLVM. We currently don't expect to match it
+    # bit-for-bit.
+    return (dtype_x, dtype_y) in [
+        ('int32', 'float16'),
+        ('int32', 'float32'),
+        ('int64', 'float16'),
+        ('int64', 'float32'),
+        ('int64', 'float64'),
+    ]
+
 # ---------------
 # test binary ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-  for op in ['+', '-', '*', '/', '%'] \
-  for dtype_x in dtypes \
+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['+', '-', '*', '/', '%']
+    for dtype_x in dtypes
   for dtype_y in dtypes
 ])
-def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
-    _test_binary(dtype_x, dtype_y, expr, device=device)
+def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f' x {op} y'
+    if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
+        # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders.
+        torch_expr = '_fake_fmod(x, y)'
+    elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
+        # Triton promotes 16-bit floating-point / and % to 32-bit because there
+        # are no native div or FRem operations on float16. Since we have to
+        # convert anyway, we may as well take the accuracy bump.
+        torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)'
+    else:
+        torch_expr = None
+    if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
+        with pytest.raises(AssertionError, match='Arrays are not almost equal'):
+            _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+    else:
+        _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device)
+
 
 
 # ---------------
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 688508265..e1091eff7 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -482,7 +482,8 @@ class CodeGenerator(ast.NodeVisitor):
         with warnings.catch_warnings():
             # The ast library added visit_Constant and deprecated some other
            # methods but we can't move to that without breaking Python 3.6 and 3.7.
-            warnings.simplefilter("ignore", DeprecationWarning)
+            warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
+            warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
             return super().visit(node)
 
     def generic_visit(self, node):
@@ -905,7 +906,7 @@ class JITFunction:
             node = generator.last_node
             if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                 raise e
-            raise CompilationError(self.src, node, e)
+            raise CompilationError(self.src, node) from e
 
     # - when `.src` attribute is set, cache path needs
     #   to be reinitialized
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 08ad62580..051a8f378 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -89,14 +89,21 @@ def assert_allclose(x, y, tol=1e-2):
     assert allclose(x, y, tol)
 
 
-def random(shape, dtype, device):
-    torch.manual_seed(0)
+def random(shape, dtype, device, seed=0):
+    """
+    Override the seed in tests if you're calling this function twice and don't
+    want the same result for both calls.
+    """
+    torch.manual_seed(seed)
     if isinstance(shape, int):
         shape = (shape, )
     if dtype == torch.bool:
         return torch.randint(0, 2, shape, dtype=dtype, device=device)
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-        return torch.randint(1, 32, shape, dtype=dtype, device=device)
+        iinfo = torch.iinfo(dtype)
+        x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device)
+        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
+        return x
    if dtype in [torch.float16, torch.float32, torch.float64]:
         return torch.normal(0, 1, shape, dtype=dtype, device=device)
     raise RuntimeError(f'Unknown dtype {dtype}')
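Note (not part of the patch): the promotion rule that the new computation_type implements can be summarized with a small, hypothetical Python sketch. The dtype strings below are illustrative stand-ins for the ir::type queries in dispatch.cc, not Triton API.

    # Illustrative sketch of the promotion rules in computation_type above.
    # Dtype strings stand in for ir::type; div_or_mod mirrors the DivOrMod flag.
    def computation_type(a, b, div_or_mod=False):
        if 'fp64' in (a, b):
            return 'fp64'
        if 'fp32' in (a, b):
            return 'fp32'
        if 'fp16' in (a, b):
            # PTX has no native fp16 division or remainder, so / and % promote to fp32.
            return 'fp32' if div_or_mod else 'fp16'
        # both operands are integers: integer promotion applies (omitted here)
        return 'integer promotion'

    assert computation_type('fp16', 'fp16') == 'fp16'                   # +, -, * stay in fp16
    assert computation_type('fp16', 'fp16', div_or_mod=True) == 'fp32'  # / and % are promoted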
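Note (not part of the patch): _fake_fmod relies on the fact that torch.remainder takes the sign of the divisor while C-style fmod takes the sign of the dividend, so shifting the nonzero results whose operands have opposite signs by y converts one into the other. A small standalone check, assuming a torch build (>= 1.8) whose torch.fmod supports integer tensors:

    import torch

    def _fake_fmod(x, y):
        # Same construction as in the patch: start from torch.remainder (sign of y)
        # and shift the mismatched nonzero entries by y to get fmod (sign of x).
        z = torch.remainder(x, y)
        return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z)

    x = torch.tensor([7, -7, 7, -7, 6, -6], dtype=torch.int32)
    y = torch.tensor([3, 3, -3, -3, 3, 3], dtype=torch.int32)

    print(torch.remainder(x, y))  # [ 1,  2, -2, -1,  0,  0]  (sign of y)
    print(torch.fmod(x, y))       # [ 1, -1,  1, -1,  0,  0]  (sign of x)
    assert torch.equal(_fake_fmod(x, y), torch.fmod(x, y))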