[FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway. - Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now. - When testing signed integer operations, use random numbers from the full range of the type. - Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type. - Fix a bad `CompilationError` invocation. - Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
This commit is contained in:
committed by
GitHub
parent
4a8953efa3
commit
5cdb948c05
@@ -33,21 +33,28 @@ ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
|
||||
return a_rank > b_rank ? a_ty : b_ty;
|
||||
}
|
||||
|
||||
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
|
||||
enum class DivOrMod { NO, YES };
|
||||
|
||||
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
|
||||
context &ctx = a_ty->get_context();
|
||||
// 1) if one operand is double, the other is implicitly
|
||||
// converted to double
|
||||
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
|
||||
if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
|
||||
return type::get_fp64_ty(ctx);
|
||||
// 2) if one operand is float, the other is implicitly
|
||||
// converted to float
|
||||
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
|
||||
if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
|
||||
return type::get_fp32_ty(ctx);
|
||||
// 3 ) if one operand is half, the other is implicitly
|
||||
// converted to half
|
||||
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
|
||||
return type::get_fp16_ty(ctx);
|
||||
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||
// 3 ) if one operand is half, the other is implicitly converted to half
|
||||
// unless we're doing / or %, which do not exist natively in PTX for fp16.
|
||||
if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
|
||||
if (div_or_mod == DivOrMod::YES) {
|
||||
return type::get_fp32_ty(ctx);
|
||||
} else {
|
||||
return type::get_fp16_ty(ctx);
|
||||
}
|
||||
}
|
||||
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||
throw_unreachable("augment_types");
|
||||
// 4 ) both operands are integer and undergo
|
||||
// integer promotion
|
||||
@@ -78,7 +85,7 @@ void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
|
||||
|
||||
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
|
||||
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
|
||||
bool arithmetic_check = true){
|
||||
bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
|
||||
// implicit broadcasting
|
||||
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
|
||||
// implicit typecasting
|
||||
@@ -86,8 +93,8 @@ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* buil
|
||||
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
|
||||
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
|
||||
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
|
||||
if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){
|
||||
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty);
|
||||
if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
|
||||
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
|
||||
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
|
||||
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
|
||||
}
|
||||
@@ -140,7 +147,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
|
||||
}
|
||||
|
||||
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, false, false, false);
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
// float / int
|
||||
@@ -168,7 +175,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
|
||||
}
|
||||
|
||||
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
|
||||
binary_op_type_checking(input, other, builder, false, false, false);
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
|
||||
@@ -181,7 +188,7 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
|
||||
}
|
||||
|
||||
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder);
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float % int
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
@@ -193,51 +200,43 @@ ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builde
|
||||
}
|
||||
|
||||
|
||||
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){
|
||||
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, false, false, false);
|
||||
ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
|
||||
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
|
||||
throw_incompatible_types(input_sca_ty, other_sca_ty);
|
||||
// for some reason pytorch assigns the result of binary op to have the type of the lhs...
|
||||
if(force_lhs_type){
|
||||
if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
|
||||
other = dispatch::cast(other, input_sca_ty, builder);
|
||||
}
|
||||
else{
|
||||
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
|
||||
input = dispatch::cast(input, other_sca_ty, builder);
|
||||
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
|
||||
other = dispatch::cast(other, input_sca_ty, builder);
|
||||
}
|
||||
|
||||
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
|
||||
input = dispatch::cast(input, other_sca_ty, builder);
|
||||
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
|
||||
other = dispatch::cast(other, input_sca_ty, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
bitwise_op_type_checking(input, other, builder, true);
|
||||
bitwise_op_type_checking(input, other, builder);
|
||||
return builder->create_and(input, other);
|
||||
}
|
||||
|
||||
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
bitwise_op_type_checking(input, other, builder, true);
|
||||
bitwise_op_type_checking(input, other, builder);
|
||||
return builder->create_or(input, other);
|
||||
}
|
||||
|
||||
|
||||
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
bitwise_op_type_checking(input, other, builder, true);
|
||||
bitwise_op_type_checking(input, other, builder);
|
||||
return builder->create_xor(input, other);
|
||||
}
|
||||
|
||||
|
||||
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
bitwise_op_type_checking(input, other, builder, false);
|
||||
bitwise_op_type_checking(input, other, builder);
|
||||
return builder->create_lshr(input, other);
|
||||
}
|
||||
|
||||
|
||||
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
bitwise_op_type_checking(input, other, builder, false);
|
||||
bitwise_op_type_checking(input, other, builder);
|
||||
return builder->create_shl(input, other);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user