[LANG] Minor semantic changes (#121)
* Now using unordered instead of ordered float (fixes NaN issues) * Bool -> int32 now converts to 1 rather than -1 * Reduce extend arguments to 32-bits if possible
This commit is contained in:
committed by
Philippe Tillet
parent
0274429429
commit
80c86ecf4a
@@ -119,6 +119,12 @@ public:
|
||||
value *create_fcmpOGE(value *lhs, value *rhs);
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpONE(value *lhs, value *rhs);
|
||||
value *create_fcmpULT(value *lhs, value *rhs);
|
||||
value *create_fcmpUGT(value *lhs, value *rhs);
|
||||
value *create_fcmpULE(value *lhs, value *rhs);
|
||||
value *create_fcmpUGE(value *lhs, value *rhs);
|
||||
value *create_fcmpUEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpUNE(value *lhs, value *rhs);
|
||||
// Logical
|
||||
value *create_and(value *lhs, value *rhs);
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
|
@@ -54,7 +54,6 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||
// run passes
|
||||
dce.run(ir);
|
||||
//ir::print(ir, std::cout);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
pipeline.run(ir);
|
||||
|
@@ -1554,7 +1554,7 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
||||
void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
std::map<indices_t, Value*> partial;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
|
||||
Value *acc = nullptr;
|
||||
|
||||
// reduce within thread
|
||||
@@ -1563,13 +1563,13 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
acc = !acc ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within wrap
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false),
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
|
||||
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
|
||||
for(int i = 16; i > 0; i >>= 1)
|
||||
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
|
||||
// pointers
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ty, addr_space));
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* warp = udiv(thread, i32(32));
|
||||
Value* lane = urem(thread, i32(32));
|
||||
@@ -1688,10 +1688,10 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
// neutral element
|
||||
Value *neutral;
|
||||
switch(op) {
|
||||
case ir::reduce_inst::ADD: neutral = i32(0); break;
|
||||
case ir::reduce_inst::SUB: neutral = i32(0); break;
|
||||
case ir::reduce_inst::MAX: neutral = i32(INT32_MIN); break;
|
||||
case ir::reduce_inst::MIN: neutral = i32(INT32_MAX); break;
|
||||
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
|
||||
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
|
||||
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
|
||||
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
|
||||
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
|
||||
|
@@ -261,6 +261,12 @@ DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT)
|
||||
DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ)
|
||||
DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
|
||||
|
||||
DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE)
|
||||
DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT)
|
||||
DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE)
|
||||
DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT)
|
||||
DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ)
|
||||
DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -275,7 +275,7 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float > float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGT(input, other);
|
||||
return builder->create_fcmpUGT(input, other);
|
||||
// int > int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGT(input, other);
|
||||
@@ -287,7 +287,7 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float >= float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGE(input, other);
|
||||
return builder->create_fcmpUGE(input, other);
|
||||
// int >= int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGE(input, other);
|
||||
@@ -299,7 +299,7 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float < float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLT(input, other);
|
||||
return builder->create_fcmpULT(input, other);
|
||||
// int < int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLT(input, other);
|
||||
@@ -311,7 +311,7 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float < float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLE(input, other);
|
||||
return builder->create_fcmpULE(input, other);
|
||||
// int < int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLE(input, other);
|
||||
@@ -323,7 +323,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float == float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOEQ(input, other);
|
||||
return builder->create_fcmpUEQ(input, other);
|
||||
// int == int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpEQ(input, other);
|
||||
@@ -335,7 +335,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float == float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpONE(input, other);
|
||||
return builder->create_fcmpUNE(input, other);
|
||||
// int == int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpNE(input, other);
|
||||
@@ -454,7 +454,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
// Int cast
|
||||
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
|
||||
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
|
||||
return builder->create_int_cast(input, dst_ty, true);
|
||||
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
|
||||
// Float -> Int
|
||||
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
|
||||
if(dst_sca_ty->is_bool_ty())
|
||||
@@ -651,6 +651,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir:
|
||||
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
|
||||
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// input is extended to 32-bits if necessary
|
||||
// this increases numerical accuracy and can be done pretty much for free
|
||||
// on GPUs
|
||||
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
|
||||
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_reduce(input, FLOAT_OP, axis);
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
|
Reference in New Issue
Block a user