diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index bfbbf6c66..cd4b1243d 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -119,6 +119,12 @@ public: value *create_fcmpOGE(value *lhs, value *rhs); value *create_fcmpOEQ(value *lhs, value *rhs); value *create_fcmpONE(value *lhs, value *rhs); + value *create_fcmpULT(value *lhs, value *rhs); + value *create_fcmpUGT(value *lhs, value *rhs); + value *create_fcmpULE(value *lhs, value *rhs); + value *create_fcmpUGE(value *lhs, value *rhs); + value *create_fcmpUEQ(value *lhs, value *rhs); + value *create_fcmpUNE(value *lhs, value *rhs); // Logical value *create_and(value *lhs, value *rhs); value *create_xor(value *lhs, value *rhs); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 66d3c238c..a39d927dc 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -54,7 +54,6 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps); // run passes dce.run(ir); - //ir::print(ir, std::cout); peephole.run(ir); dce.run(ir); pipeline.run(ir); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 349056015..0791ee6fc 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1554,7 +1554,7 @@ Value* generator::shared_off(const std::vector& shapes, const std::vec void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { std::map partial; ir::value *arg = x->get_operand(0); - Type *ty = cvt(x->get_type()->get_scalar_ty()); + Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); Value *acc = nullptr; // reduce within thread @@ -1563,13 +1563,13 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function 0; i >>= 1) acc = do_acc(acc, call(shfl, {acc, i32(i)})); // pointers unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); - Value *base = bit_cast(shmem_, ptr_ty(ty, addr_space)); + Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* warp = udiv(thread, i32(32)); Value* lane = urem(thread, i32(32)); @@ -1688,10 +1688,10 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { // neutral element Value *neutral; switch(op) { - case ir::reduce_inst::ADD: neutral = i32(0); break; - case ir::reduce_inst::SUB: neutral = i32(0); break; - case ir::reduce_inst::MAX: neutral = i32(INT32_MIN); break; - case ir::reduce_inst::MIN: neutral = i32(INT32_MAX); break; + case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break; + case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; + case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; + case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index ad0faeb8d..80914d145 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -261,6 +261,12 @@ DEFINE_FCMP_INSTR(OGT, cmp_pred_t::FCMP_OGT) DEFINE_FCMP_INSTR(OEQ, cmp_pred_t::FCMP_OEQ) DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) +DEFINE_FCMP_INSTR(ULE, cmp_pred_t::FCMP_ULE) +DEFINE_FCMP_INSTR(ULT, cmp_pred_t::FCMP_ULT) +DEFINE_FCMP_INSTR(UGE, cmp_pred_t::FCMP_UGE) +DEFINE_FCMP_INSTR(UGT, cmp_pred_t::FCMP_UGT) +DEFINE_FCMP_INSTR(UEQ, cmp_pred_t::FCMP_UEQ) +DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) //===----------------------------------------------------------------------===// diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index ee3be6bcc..83076e685 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -275,7 +275,7 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float > float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGT(input, other); + return builder->create_fcmpUGT(input, other); // int > int else if (scalar_ty->is_integer_ty()) return builder->create_icmpSGT(input, other); @@ -287,7 +287,7 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float >= float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGE(input, other); + return builder->create_fcmpUGE(input, other); // int >= int else if (scalar_ty->is_integer_ty()) return builder->create_icmpSGE(input, other); @@ -299,7 +299,7 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder * ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float < float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLT(input, other); + return builder->create_fcmpULT(input, other); // int < int else if (scalar_ty->is_integer_ty()) return builder->create_icmpSLT(input, other); @@ -311,7 +311,7 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float < float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLE(input, other); + return builder->create_fcmpULE(input, other); // int < int else if (scalar_ty->is_integer_ty()) return builder->create_icmpSLE(input, other); @@ -323,7 +323,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float == float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOEQ(input, other); + return builder->create_fcmpUEQ(input, other); // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpEQ(input, other); @@ -335,7 +335,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder * ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float == float if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpONE(input, other); + return builder->create_fcmpUNE(input, other); // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpNE(input, other); @@ -454,7 +454,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build // Int cast if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) - return builder->create_int_cast(input, dst_ty, true); + return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty()); // Float -> Int if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ if(dst_sca_ty->is_bool_ty()) @@ -651,6 +651,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir: ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // input is extended to 32-bits if necessary + // this increases numerical accuracy and can be done pretty much for free + // on GPUs + if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32) + input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder); if (scalar_ty->is_floating_point_ty()) return builder->create_reduce(input, FLOAT_OP, axis); else if (scalar_ty->is_integer_ty())