[lang] fixup in cast type
This commit is contained in:
@@ -237,26 +237,26 @@ private:
|
|||||||
op_t op_;
|
op_t op_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name, op) \
|
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
|
||||||
class name : public cast_inst{ \
|
class name : public cast_inst{ \
|
||||||
friend class cast_inst; \
|
friend class cast_inst; \
|
||||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||||
: cast_inst(ty, v, name, next, op){ } \
|
: cast_inst(ty, v, name, next, op){ } \
|
||||||
};
|
};
|
||||||
|
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst, llvm::Instruction::CastOps::Trunc)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, llvm::Instruction::CastOps::Trunc)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst, llvm::Instruction::CastOps::ZExt)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, llvm::Instruction::CastOps::ZExt)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst, llvm::Instruction::CastOps::SExt)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, llvm::Instruction::CastOps::SExt)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
|
||||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
|
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// terminator_inst classes
|
// terminator_inst classes
|
||||||
|
@@ -65,7 +65,7 @@ public:
|
|||||||
target_(target) { }
|
target_(target) { }
|
||||||
|
|
||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
ir::print(module, std::cout);
|
// ir::print(module, std::cout);
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
}
|
}
|
||||||
|
@@ -58,10 +58,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
|||||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||||
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
||||||
return;
|
return;
|
||||||
else{
|
else
|
||||||
std::cout << v->get_name() << std::endl;
|
|
||||||
shapes = v->get_type()->get_tile_shapes();
|
shapes = v->get_type()->get_tile_shapes();
|
||||||
}
|
|
||||||
// Reshape
|
// Reshape
|
||||||
if(dynamic_cast<ir::reshape_inst*>(v)){
|
if(dynamic_cast<ir::reshape_inst*>(v)){
|
||||||
ir::value *op = v->get_operand(0);
|
ir::value *op = v->get_operand(0);
|
||||||
|
@@ -246,20 +246,22 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
|
|||||||
it_true_begin++;
|
it_true_begin++;
|
||||||
auto it_true_end = instructions.end();
|
auto it_true_end = instructions.end();
|
||||||
for(auto it = it_true_begin; it != it_true_end; it++)
|
for(auto it = it_true_begin; it != it_true_end; it++)
|
||||||
|
// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||||
(*it)->set_mask_pred(true_mask);
|
(*it)->set_mask_pred(true_mask);
|
||||||
/* false value */
|
/* false value */
|
||||||
ir::value *false_mask = mask->get_result(1);
|
ir::value *false_mask = mask->get_result(1);
|
||||||
auto it_false_begin = instructions.end();
|
auto it_false_begin = instructions.end();
|
||||||
it_false_begin--;
|
it_false_begin--;
|
||||||
ir::value *false_value = false_value_->codegen(mod);
|
ir::value *false_value = false_value_->codegen(mod);
|
||||||
it_false_begin++;
|
|
||||||
implicit_broadcast(mod, pred, false_value);
|
implicit_broadcast(mod, pred, false_value);
|
||||||
auto it_false_end = instructions.end();
|
|
||||||
for(auto it = it_false_begin; it != it_false_end; it++)
|
|
||||||
(*it)->set_mask_pred(false_mask);
|
|
||||||
/* cast */
|
|
||||||
bool is_float, is_ptr, is_int, is_signed;
|
bool is_float, is_ptr, is_int, is_signed;
|
||||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||||
|
it_false_begin++;
|
||||||
|
auto it_false_end = instructions.end();
|
||||||
|
for(auto it = it_false_begin; it != it_false_end; it++)
|
||||||
|
// if(!dynamic_cast<ir::retile_inst*>(*it))
|
||||||
|
(*it)->set_mask_pred(false_mask);
|
||||||
|
/* psi */
|
||||||
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@@ -11,6 +11,8 @@ namespace lang{
|
|||||||
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
||||||
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
||||||
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
||||||
|
if(src->get_type()->is_tile_ty())
|
||||||
|
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
|
||||||
bool src_signed = false;
|
bool src_signed = false;
|
||||||
bool dst_signed = false;
|
bool dst_signed = false;
|
||||||
if(src_scalar_ty == dst_scalar_ty)
|
if(src_scalar_ty == dst_scalar_ty)
|
||||||
|
Reference in New Issue
Block a user