[LANG] Minor semantic changes (#121)

* Now using unordered instead of ordered float (fixes NaN issues)
* Bool -> int32 now converts to 1 rather than -1
* Reduce extend arguments to 32-bits if possible
This commit is contained in:
Philippe Tillet
2021-06-01 21:13:21 -04:00
committed by Philippe Tillet
parent 0274429429
commit 80c86ecf4a
5 changed files with 31 additions and 15 deletions

View File

@@ -261,6 +261,12 @@ DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE)
DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT)
DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE)
DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT)
DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ)
DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
//===----------------------------------------------------------------------===//

View File

@@ -275,7 +275,7 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
return builder->create_fcmpUGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(input, other);
@@ -287,7 +287,7 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
return builder->create_fcmpUGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(input, other);
@@ -299,7 +299,7 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
return builder->create_fcmpULT(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(input, other);
@@ -311,7 +311,7 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
return builder->create_fcmpULE(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(input, other);
@@ -323,7 +323,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
return builder->create_fcmpUEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
@@ -335,7 +335,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpONE(input, other);
return builder->create_fcmpUNE(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
@@ -454,7 +454,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
@@ -651,6 +651,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir:
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// input is extended to 32-bits if necessary
// this increases numerical accuracy and can be done pretty much for free
// on GPUs
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())