diff --git a/include/ir/basic_block.h b/include/ir/basic_block.h index 240123796..3a4989f27 100644 --- a/include/ir/basic_block.h +++ b/include/ir/basic_block.h @@ -24,8 +24,9 @@ public: public: - // parent + // accessors function* get_parent() { return parent_; } + context& get_context() { return ctx_; } // get iterator to first instruction that is not a phi iterator get_first_non_phi(); @@ -51,10 +52,8 @@ public: inline const instruction &back() const { return *inst_list_.back(); } inline instruction &back() { return *inst_list_.back(); } - // get predecessors - const std::vector& get_predecessors() const; - - // add predecessor + // predecessors + const std::vector& get_predecessors() const { return preds_; } void add_predecessor(basic_block* pred); // factory functions diff --git a/include/ir/context.h b/include/ir/context.h index 8b80f7491..35907ede1 100644 --- a/include/ir/context.h +++ b/include/ir/context.h @@ -4,8 +4,15 @@ namespace tdl{ namespace ir{ +class type; + /* Context */ class context { +public: + type *get_void_ty(); + type *get_int1_ty(); + +private: }; } diff --git a/include/ir/instructions.h b/include/ir/instructions.h index 95c26de1d..aee0aa1d0 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -9,6 +9,7 @@ namespace tdl{ namespace ir{ class basic_block; +class context; //===----------------------------------------------------------------------===// // instruction classes @@ -17,7 +18,7 @@ class basic_block; class instruction: public user{ protected: // constructors - instruction(type *ty, unsigned num_ops, instruction *next = nullptr); + instruction(type *ty, unsigned num_ops, const std::string &name = "", instruction *next = nullptr); public: @@ -38,13 +39,17 @@ private: phi_node(type *ty, unsigned num_reserved); public: - void add_incoming(value *x, basic_block *bb); + void set_incoming_value(unsigned i, value *v); + void set_incoming_block(unsigned i, basic_block *block); + + void add_incoming(value *v, basic_block *block); // Factory methods static phi_node* create(type *ty, unsigned num_reserved); private: unsigned num_reserved_; + std::vector blocks_; }; //===----------------------------------------------------------------------===// @@ -84,11 +89,10 @@ public: typedef llvm::CmpInst::Predicate pred_t; using pcmp = llvm::CmpInst; -private: - type* make_cmp_result_type(type *ty); - protected: - cmp_inst(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next); + cmp_inst(type *ty, pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next); + + static type* make_cmp_result_type(type *ty); static bool is_fp_predicate(pred_t pred); static bool is_int_predicate(pred_t pred); @@ -116,17 +120,29 @@ public: const std::string &name = "", instruction *next = nullptr); }; +//===----------------------------------------------------------------------===// +// unary_inst classes +//===----------------------------------------------------------------------===// + +class unary_inst: public instruction { +protected: + unary_inst(type *Ty, value *v, const std::string &name, instruction *next); +}; + + //===----------------------------------------------------------------------===// // cast_inst classes //===----------------------------------------------------------------------===// -class cast_inst: public instruction{ +class cast_inst: public unary_inst{ + using unary_inst::unary_inst; + using ic = llvm::Instruction::CastOps; + public: typedef llvm::CastInst::CastOps op_t; -protected: - // Constructors - cast_inst(op_t op, value *arg, type *ty, const std::string &name, instruction *next); +private: + bool is_valid(op_t op, value *arg, type *ty); public: // Factory methods @@ -135,33 +151,67 @@ public: static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name = "", instruction *next = nullptr); - private: op_t op_; }; +#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name) \ + class name : public cast_inst{ \ + friend class cast_inst; \ + using cast_inst::cast_inst; \ + }; + +TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst) +TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst) + //===----------------------------------------------------------------------===// // terminator_inst classes //===----------------------------------------------------------------------===// class terminator_inst: public instruction{ -public: + using instruction::instruction; }; -class return_inst: public instruction{ +// return instruction +class return_inst: public terminator_inst{ + return_inst(context &ctx, value *ret_val, instruction *next); + +public: + // accessors + value *get_return_value() + { return get_num_operands() ? get_operand(0) : nullptr; } + + unsigned get_num_successors() const { return 0; } + + // factory methods + static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr); }; -//===----------------------------------------------------------------------===// -// branch_inst classes -//===----------------------------------------------------------------------===// +// conditional/unconditional branch instruction + +class branch_inst: public terminator_inst{ + branch_inst(basic_block *dst, instruction *next); + branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next); -class branch_inst: public instruction{ public: + + // factory methods static branch_inst* create(basic_block *dest, - const std::string &name = "", instruction *next = nullptr); + instruction *next = nullptr); static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest, - const std::string &name = "", instruction *next = nullptr); + instruction *next = nullptr); }; //===----------------------------------------------------------------------===// @@ -169,9 +219,20 @@ public: //===----------------------------------------------------------------------===// class getelementptr_inst: public instruction{ + getelementptr_inst(type *pointee_ty, value *ptr, const std::vector &idx, const std::string &name, instruction *next); + +private: + static type *get_return_type(type *ty, value *ptr, const std::vector &idx); + static type *get_indexed_type_impl(type *ty, const std::vector &idx); + static type *get_indexed_type(type *ty, const std::vector &idx); + public: static getelementptr_inst* create(value *ptr, const std::vector &idx, const std::string &name = "", instruction *next = nullptr); + +private: + type *source_elt_ty; + type *res_elt_ty; }; diff --git a/include/ir/type.h b/include/ir/type.h index 6a50690ed..874bffcdd 100644 --- a/include/ir/type.h +++ b/include/ir/type.h @@ -7,24 +7,40 @@ namespace tdl{ namespace ir{ class context; +class value; /* Type */ class type { public: - bool is_integer_ty() const; - bool is_pointer_ty() const; - bool is_float_ty() const; - bool is_double_ty() const; - bool is_floating_point_ty() const; + virtual ~type(){} + + // accessors + context &get_context() const; // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bit_width() const; + unsigned get_scalar_bitsize() const; const std::vector &get_tile_shapes() const; + type *get_scalar_ty() const; + unsigned get_pointer_address_space() const; + + // type predicates + bool is_int_or_tileint_ty(); + bool is_integer_ty() const; + bool is_integer_ty(unsigned width) const; + bool is_pointer_ty() const; + bool is_float_ty() const; + bool is_double_ty() const; + bool is_floating_point_ty() const; + bool is_sized() const; + bool is_tile_ty() const; + // Factory methods static type* get_void_ty(context &ctx); static type* get_float_ty(context &ctx); static type* get_double_ty(context &ctx); + }; class integer_type: public type { @@ -32,14 +48,22 @@ public: static integer_type* get(context &ctx, unsigned width); }; +class composite_type: public type{ +public: + bool index_valid(value *idx) const; + type* get_type_at_index(value *idx) const; +}; + class tile_type: public type { public: static tile_type* get(type *ty, const std::vector &shapes); + static tile_type* get_same_shapes(type *ty, type *ref); }; class pointer_type: public type { public: static pointer_type* get(type *ty, unsigned address_space); + type *get_element_ty() const; }; class function_type: public type { diff --git a/include/ir/value.h b/include/ir/value.h index b7a017200..effa44014 100644 --- a/include/ir/value.h +++ b/include/ir/value.h @@ -55,6 +55,9 @@ private: //===----------------------------------------------------------------------===// class user: public value{ +protected: + void resize_ops(unsigned n) { ops_.resize(n); } + public: // Constructor user(type *ty, unsigned num_ops, const std::string &name = "") diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 383085c10..b72341d72 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -1,6 +1,8 @@ +#include "ir/context.h" #include "ir/basic_block.h" #include "ir/instructions.h" #include "ir/constant.h" +#include "ir/type.h" namespace tdl{ namespace ir{ @@ -9,8 +11,8 @@ namespace ir{ // instruction classes //===----------------------------------------------------------------------===// -instruction::instruction(type *ty, unsigned num_ops, instruction *next) - : user(ty, num_ops) { +instruction::instruction(type *ty, unsigned num_ops, const std::string &name, instruction *next) + : user(ty, num_ops, name) { if(next){ basic_block *block = next->get_parent(); assert(block && "Next instruction is not in a basic block!"); @@ -23,9 +25,29 @@ instruction::instruction(type *ty, unsigned num_ops, instruction *next) // phi_node classes //===----------------------------------------------------------------------===// -// Add incoming -void phi_node::add_incoming(value *x, basic_block *bb){ +// Set incoming value +void phi_node::set_incoming_value(unsigned i, value *v){ + assert(v && "PHI node got a null value!"); + assert(get_type() == v->get_type() && + "All operands to PHI node must be the same type as the PHI node!"); + set_operand(i, v); +} +// Set incoming block +void phi_node::set_incoming_block(unsigned i, basic_block *block){ + assert(block && "PHI node got a null basic block!"); + blocks_[i] = block; +} + +// Add incoming +void phi_node::add_incoming(value *v, basic_block *block){ + if(get_num_operands()==num_reserved_){ + num_reserved_++; + resize_ops(num_reserved_); + blocks_.resize(num_reserved_); + } + set_incoming_value(get_num_operands() - 1, v); + set_incoming_block(get_num_operands() - 1, block); } // Factory methods @@ -39,7 +61,7 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved){ //===----------------------------------------------------------------------===// binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) - : instruction(ty, 2, next), op_(op){ + : instruction(ty, 2, name, next), op_(op){ set_operand(0, lhs); set_operand(1, rhs); } @@ -72,6 +94,24 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name // cmp_inst classes //===----------------------------------------------------------------------===// +// cmp_inst + +cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next) + : instruction(ty, 2, name, next), pred_(pred) { + set_operand(0, lhs); + set_operand(1, rhs); +} + +type* cmp_inst::make_cmp_result_type(type *ty){ + type* int1_ty = ty->get_context().get_int1_ty(); + if (tile_type* tile_ty = dynamic_cast(ty)) + return tile_type::get_same_shapes(int1_ty, tile_ty); + return int1_ty; +} + + + + bool cmp_inst::is_fp_predicate(pred_t pred) { return pred >= pcmp::FIRST_FCMP_PREDICATE && pred <= pcmp::LAST_FCMP_PREDICATE; } @@ -84,15 +124,159 @@ bool cmp_inst::is_int_predicate(pred_t pred) { icmp_inst* icmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_int_predicate(pred)); - return new icmp_inst(pred, lhs, rhs, name, next); + type *res_ty = make_cmp_result_type(lhs->get_type()); + return new icmp_inst(res_ty, pred, lhs, rhs, name, next); } // fcmp_inst fcmp_inst* fcmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_fp_predicate(pred)); - return new fcmp_inst(pred, lhs, rhs, name, next); + type *res_ty = make_cmp_result_type(lhs->get_type()); + return new fcmp_inst(res_ty, pred, lhs, rhs, name, next); } +//===----------------------------------------------------------------------===// +// unary_inst classes +//===----------------------------------------------------------------------===// + +unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next) + : instruction(ty, 1, name, next) { + set_operand(0, v); +} + +//===----------------------------------------------------------------------===// +// cast_inst classes +//===----------------------------------------------------------------------===// + +cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &name, instruction *next){ + assert(is_valid(op, arg, ty) && "Invalid cast!"); + // Construct and return the appropriate CastInst subclass + switch (op) { + case ic::Trunc: return new trunc_inst (ty, arg, name, next); + case ic::ZExt: return new z_ext_inst (ty, arg, name, next); + case ic::SExt: return new s_ext_inst (ty, arg, name, next); + case ic::FPTrunc: return new fp_trunc_inst (ty, arg, name, next); + case ic::FPExt: return new fp_ext_inst (ty, arg, name, next); + case ic::UIToFP: return new ui_to_fp_inst (ty, arg, name, next); + case ic::SIToFP: return new si_to_fp_inst (ty, arg, name, next); + case ic::FPToUI: return new fp_to_ui_inst (ty, arg, name, next); + case ic::FPToSI: return new fp_to_si_inst (ty, arg, name, next); + case ic::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next); + case ic::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next); + case ic::BitCast: return new bit_cast_inst (ty, arg, name, next); + case ic::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next); + default: throw std::runtime_error("unreachable"); + } +} + +cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){ + type *arg_ty = arg->get_type(); + assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!"); + unsigned arg_bits = arg_ty->get_scalar_bitsize(); + unsigned dst_bits = ty->get_scalar_bitsize(); + op_t op = (arg_bits == dst_bits ? ic::BitCast : + (arg_bits > dst_bits ? ic::Trunc : + (is_signed ? ic::SExt : ic::ZExt))); + return create(op, arg, ty, name, next); +} + +//===----------------------------------------------------------------------===// +// terminator_inst classes +//===----------------------------------------------------------------------===// + + +// return_inst + +return_inst::return_inst(context &ctx, value *ret_val, instruction *next) + : terminator_inst(ctx.get_void_ty(), !!ret_val, "", next){ + if(ret_val) + set_operand(0, ret_val); +} + +return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){ + return new return_inst(ctx, ret_val, next); +} + + +// conditional/unconditional branch + +branch_inst::branch_inst(basic_block *dst, instruction *next) + : terminator_inst(dst->get_context().get_void_ty(), 1, "", next){ + set_operand(0, dst); +} + +branch_inst::branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next) + : terminator_inst(if_dst->get_context().get_void_ty(), 3, "", next){ + assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); + set_operand(0, if_dst); + set_operand(1, else_dst); + set_operand(2, cond); +} + +branch_inst* branch_inst::create(basic_block *dst, instruction *next) { + assert(dst && "Branch destination may not be null!"); + return new branch_inst(dst, next); +} + +branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) { + assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); + return new branch_inst(if_dst, else_dst, cond, next); +} + + +//===----------------------------------------------------------------------===// +// getelementptr_inst classes +//===----------------------------------------------------------------------===// + +getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector &idx, const std::string &name, instruction *next) + : instruction(get_return_type(pointee_ty, ptr, idx), idx.size(), name, next), + source_elt_ty(pointee_ty), + res_elt_ty(get_indexed_type(pointee_ty, idx)){ + type *expected_ty = ((pointer_type*)(get_type()->get_scalar_ty()))->get_element_ty(); + assert(res_elt_ty == expected_ty); + set_operand(0, ptr); + for(size_t i = 0; i < idx.size(); i++) + set_operand(1 + i, idx[i]); +} + +type *getelementptr_inst::get_return_type(type *elt_ty, value *ptr, const std::vector &idx_list) { + // result pointer type + type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_pointer_address_space()); + // Tile GEP + if(ptr->get_type()->is_tile_ty()) + return tile_type::get_same_shapes(ptr_ty, ptr->get_type()); + for(value *idx : idx_list) + if (idx->get_type()->is_tile_ty()) + return tile_type::get_same_shapes(ptr_ty, idx->get_type()); + // Scalar GEP + return ptr_ty; +} + +type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector &idx_list) { + if(idx_list.empty()) + return ty; + if(!ty->is_sized()) + return nullptr; + unsigned cur_idx = 1; + for(; cur_idx != idx_list.size(); cur_idx++){ + composite_type *cty = dynamic_cast(ty); + if(!cty || cty->is_pointer_ty()) + break; + value *idx = idx_list[cur_idx]; + if(!cty->index_valid(idx)) + break; + ty = cty->get_type_at_index(idx); + } + return (cur_idx == idx_list.size())? ty : nullptr; +} + +type *getelementptr_inst::get_indexed_type(type *ty, const std::vector &idx_list) { + type *result = get_indexed_type_impl(ty, idx_list); + assert(result && "invalid GEP type!"); + return result; +} + + } }