diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index a08e89a85..ee22a5b25 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -237,26 +237,26 @@ private: op_t op_; }; -#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name, op) \ +#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \ class name : public cast_inst{ \ friend class cast_inst; \ name(type *ty, value *v, const std::string &name, instruction *next) \ : cast_inst(ty, v, name, next, op){ } \ }; -TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst, llvm::Instruction::CastOps::Trunc) -TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst, llvm::Instruction::CastOps::ZExt) -TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst, llvm::Instruction::CastOps::SExt) -TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc) -TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst, llvm::Instruction::CastOps::FPExt) -TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP) -TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP) -TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI) -TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI) -TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt) -TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr) -TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst, llvm::Instruction::CastOps::BitCast) -TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast) +TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, llvm::Instruction::CastOps::Trunc) +TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, llvm::Instruction::CastOps::ZExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, llvm::Instruction::CastOps::SExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, llvm::Instruction::CastOps::FPExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP) +TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI) +TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr) +TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, llvm::Instruction::CastOps::BitCast) +TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast) //===----------------------------------------------------------------------===// // terminator_inst classes diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index b74ae7c83..6bc377c95 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -65,7 +65,7 @@ public: target_(target) { } void target_independent(ir::module &module) { - ir::print(module, std::cout); +// ir::print(module, std::cout); optimize_dot.run(module); optimize_trans.run(module); } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 3821ecdb2..9f8d88492 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -58,10 +58,8 @@ void tune::init_c_graph(ir::instruction *v) { shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); else if(auto *downcast = dynamic_cast(v)) return; - else{ - std::cout << v->get_name() << std::endl; + else shapes = v->get_type()->get_tile_shapes(); - } // Reshape if(dynamic_cast(v)){ ir::value *op = v->get_operand(0); diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index 388815164..731d50c24 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -246,20 +246,22 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{ it_true_begin++; auto it_true_end = instructions.end(); for(auto it = it_true_begin; it != it_true_end; it++) +// if(!dynamic_cast(*it)) (*it)->set_mask_pred(true_mask); /* false value */ ir::value *false_mask = mask->get_result(1); auto it_false_begin = instructions.end(); it_false_begin--; ir::value *false_value = false_value_->codegen(mod); - it_false_begin++; implicit_broadcast(mod, pred, false_value); - auto it_false_end = instructions.end(); - for(auto it = it_false_begin; it != it_false_end; it++) - (*it)->set_mask_pred(false_mask); - /* cast */ bool is_float, is_ptr, is_int, is_signed; implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); + it_false_begin++; + auto it_false_end = instructions.end(); + for(auto it = it_false_begin; it != it_false_end; it++) +// if(!dynamic_cast(*it)) + (*it)->set_mask_pred(false_mask); + /* psi */ ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); return result; } diff --git a/lib/lang/node.cpp b/lib/lang/node.cpp index 418a86fca..5c48657bf 100644 --- a/lib/lang/node.cpp +++ b/lib/lang/node.cpp @@ -11,6 +11,8 @@ namespace lang{ ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); + if(src->get_type()->is_tile_ty()) + dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type()); bool src_signed = false; bool dst_signed = false; if(src_scalar_ty == dst_scalar_ty)