From b039498d1544d8433a32635b400d8a1c1e804024 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 3 Jan 2019 00:42:37 -0500 Subject: [PATCH] [intermediate representation] added subdefinitions in types submodule --- include/ir/context.h | 10 ++- include/ir/instructions.h | 25 +++++- include/ir/type.h | 117 ++++++++++++++++++++++++---- include/ir/value.h | 2 +- lib/codegen.cpp | 14 ++-- lib/ir/context.cpp | 22 ++++++ lib/ir/instructions.cpp | 19 +++-- lib/ir/type.cpp | 156 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 328 insertions(+), 37 deletions(-) diff --git a/include/ir/context.h b/include/ir/context.h index 35907ede1..c7382a0cb 100644 --- a/include/ir/context.h +++ b/include/ir/context.h @@ -1,18 +1,22 @@ #ifndef TDL_INCLUDE_IR_CONTEXT_H #define TDL_INCLUDE_IR_CONTEXT_H +#include +#include "ir/type.h" + namespace tdl{ namespace ir{ class type; +class context_impl; /* Context */ class context { public: - type *get_void_ty(); - type *get_int1_ty(); + context(); -private: +public: + std::shared_ptr p_impl; }; } diff --git a/include/ir/instructions.h b/include/ir/instructions.h index aee0aa1d0..cb673d73e 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -36,7 +36,7 @@ private: class phi_node: public instruction{ private: - phi_node(type *ty, unsigned num_reserved); + phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next); public: void set_incoming_value(unsigned i, value *v); @@ -45,7 +45,7 @@ public: void add_incoming(value *v, basic_block *block); // Factory methods - static phi_node* create(type *ty, unsigned num_reserved); + static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr); private: unsigned num_reserved_; @@ -235,6 +235,27 @@ private: type *res_elt_ty; }; +//===----------------------------------------------------------------------===// +// retile_inst classes +//===----------------------------------------------------------------------===// + +class retile_inst: public instruction{ + +}; + +class reshape_inst: public instruction{ + +}; + +class splat_inst: public instruction{ + +}; + +class broadcast_inst: public instruction{ + +}; + + } } diff --git a/include/ir/type.h b/include/ir/type.h index 874bffcdd..43b6d1c34 100644 --- a/include/ir/type.h +++ b/include/ir/type.h @@ -8,62 +8,147 @@ namespace ir{ class context; class value; +class integer_type; /* Type */ class type { public: + enum id_t { + // primitive types + VoidTyID = 0, ///< 0: type with no size + HalfTyID, ///< 1: 16-bit floating point type + FloatTyID, ///< 2: 32-bit floating point type + DoubleTyID, ///< 3: 64-bit floating point type + LabelTyID, ///< 4: Labels + MetadataTyID, ///< 5: Metadata + TokenTyID, ///< 6: Token + // derived types + IntegerTyID, ///< 7: Arbitrary bit width integers + FunctionTyID, ///< 8: Functions + PointerTyID, ///< 9: Pointers + TileTyID, ///< 10: Tile + }; + +public: + //constructors + type(context &ctx, id_t id) : ctx_(ctx), id_(id) {} + + //destructor virtual ~type(){} // accessors - context &get_context() const; + context &get_context() const { return ctx_; } // type attributes unsigned get_fp_mantissa_width() const; - unsigned get_integer_bit_width() const; - unsigned get_scalar_bitsize() const; - const std::vector &get_tile_shapes() const; + unsigned get_integer_bitwidth() const; type *get_scalar_ty() const; + const std::vector &get_tile_shapes() const; + type *get_tile_element_ty() const; unsigned get_pointer_address_space() const; - // type predicates + // primitive predicates + bool is_void_ty() const { return id_ == VoidTyID; } + bool is_half_ty() const { return id_ == HalfTyID; } + bool is_float_ty() const { return id_ == FloatTyID; } + bool is_double_ty() const { return id_ == DoubleTyID; } + bool is_label_ty() const { return id_ == LabelTyID;} + bool is_metadata_ty() const { return id_ == MetadataTyID; } + bool is_token_ty() const { return id_ == TokenTyID; } + bool is_integer_ty() const { return id_ == IntegerTyID; } + bool is_pointer_ty() const { return id_ == PointerTyID; } + bool is_tile_ty() const { return id_ == TileTyID; } + + // Composite predicates bool is_int_or_tileint_ty(); - bool is_integer_ty() const; bool is_integer_ty(unsigned width) const; - bool is_pointer_ty() const; - bool is_float_ty() const; - bool is_double_ty() const; bool is_floating_point_ty() const; - bool is_sized() const; - bool is_tile_ty() const; + bool is_sized() const ; // Factory methods - static type* get_void_ty(context &ctx); - static type* get_float_ty(context &ctx); - static type* get_double_ty(context &ctx); + // primitive types + static type *get_void_ty(context &ctx); + static type *get_label_ty(context &ctx); + // half + static type *get_half_ty(context &ctx); + static type *get_float_ty(context &ctx); + static type *get_double_ty(context &ctx); + // integer types + static integer_type *get_int1_ty(context &ctx); + static integer_type *get_int8_ty(context &ctx); + static integer_type *get_int16_ty(context &ctx); + static integer_type *get_int32_ty(context &ctx); + static integer_type *get_int64_ty(context &ctx); + static integer_type *get_int128_ty(context &ctx); +private: + context &ctx_; + id_t id_; + +protected: + std::vector contained_tys_; }; class integer_type: public type { + friend class context_impl; + +private: + // constructors + integer_type(context &ctx, unsigned bitwidth) + : type(ctx, IntegerTyID), bitwidth_(bitwidth){ } + public: + // accessors + unsigned get_bitwidth() const { return bitwidth_; } + + // factory methods static integer_type* get(context &ctx, unsigned width); + +private: + unsigned bitwidth_; }; class composite_type: public type{ +protected: + using type::type; + public: bool index_valid(value *idx) const; type* get_type_at_index(value *idx) const; }; -class tile_type: public type { +class tile_type: public composite_type { +private: + tile_type(type *ty, const std::vector &shapes); + static bool is_valid_elt_ty(type *ty); + public: + // accessors + const std::vector& get_shapes() const { return shapes_; } + + // factory methods static tile_type* get(type *ty, const std::vector &shapes); static tile_type* get_same_shapes(type *ty, type *ref); + +private: + std::vector shapes_; }; class pointer_type: public type { +private: + pointer_type(type *ty, unsigned address_space); + static bool is_valid_elt_ty(type *ty); + public: + // accessors + unsigned get_address_space() const { return address_space_; } + type *get_element_ty() const { return contained_tys_[0]; } + + // factory methods static pointer_type* get(type *ty, unsigned address_space); - type *get_element_ty() const; + +private: + unsigned address_space_; }; class function_type: public type { diff --git a/include/ir/value.h b/include/ir/value.h index effa44014..1b26391f3 100644 --- a/include/ir/value.h +++ b/include/ir/value.h @@ -23,7 +23,7 @@ public: void add_use(use *arg); // name void set_name(const std::string &name); - type* get_type() { return ty_; } + type* get_type() const { return ty_; } private: type *ty_; diff --git a/lib/codegen.cpp b/lib/codegen.cpp index 1ef5df769..9f8ad8420 100644 --- a/lib/codegen.cpp +++ b/lib/codegen.cpp @@ -25,10 +25,10 @@ ir::type* declaration_specifier::type(ir::module *mod) const { ir::context &ctx = mod->get_context(); switch (spec_) { case VOID_T: return ir::type::get_void_ty(ctx); - case INT8_T: return ir::integer_type::get(ctx, 8); - case INT16_T: return ir::integer_type::get(ctx, 16); - case INT32_T: return ir::integer_type::get(ctx, 32); - case INT64_T: return ir::integer_type::get(ctx, 64); + case INT8_T: return ir::type::get_int8_ty(ctx); + case INT16_T: return ir::type::get_int16_ty(ctx); + case INT32_T: return ir::type::get_int32_ty(ctx); + case INT64_T: return ir::type::get_int64_ty(ctx); case FLOAT32_T: return ir::type::get_float_ty(ctx); case FLOAT64_T: return ir::type::get_double_ty(ctx); default: throw std::runtime_error("unreachable"); @@ -227,7 +227,7 @@ ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ return builder.create_fp_trunc(src, dst_ty); else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() && - src_ty->get_integer_bit_width()) + src_ty->get_integer_bitwidth()) return builder.create_int_cast(src, dst_ty, dst_signed); else @@ -259,8 +259,8 @@ inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ is_int = true; is_signed = false; - if(left_ty->get_integer_bit_width() != right_ty->get_integer_bit_width()){ - ir::value *&to_convert = (left_ty->get_integer_bit_width() > right_ty->get_integer_bit_width())?rhs:lhs; + if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ + ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; to_convert = llvm_cast(builder, to_convert, dst_ty); } diff --git a/lib/ir/context.cpp b/lib/ir/context.cpp index 8357b0ab1..56b64b4a3 100644 --- a/lib/ir/context.cpp +++ b/lib/ir/context.cpp @@ -1,7 +1,29 @@ +#include "ir/context_impl.h" #include "ir/context.h" +#include "ir/type.h" namespace tdl{ namespace ir{ +//===----------------------------------------------------------------------===// +// context implementation +//===----------------------------------------------------------------------===// + +context_impl::context_impl(context &ctx) + : void_ty(ctx, type::VoidTyID), + label_ty(ctx, type::LabelTyID), + half_ty(ctx, type::HalfTyID), + float_ty(ctx, type::FloatTyID), + double_ty(ctx, type::DoubleTyID), + int1_ty(ctx, 1), + int8_ty(ctx, 8), + int16_ty(ctx, 16), + int32_ty(ctx, 32), + int64_ty(ctx, 64), + int128_ty(ctx, 128) +{ + +} + } } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index b72341d72..a42dc0c4b 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -25,6 +25,9 @@ instruction::instruction(type *ty, unsigned num_ops, const std::string &name, in // phi_node classes //===----------------------------------------------------------------------===// +phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next) + : instruction(ty, num_reserved, name, next){ } + // Set incoming value void phi_node::set_incoming_value(unsigned i, value *v){ assert(v && "PHI node got a null value!"); @@ -51,8 +54,8 @@ void phi_node::add_incoming(value *v, basic_block *block){ } // Factory methods -phi_node* phi_node::create(type *ty, unsigned num_reserved){ - return new phi_node(ty, num_reserved); +phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){ + return new phi_node(ty, num_reserved, name, next); } @@ -103,7 +106,7 @@ cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, cons } type* cmp_inst::make_cmp_result_type(type *ty){ - type* int1_ty = ty->get_context().get_int1_ty(); + type* int1_ty = type::get_int1_ty(ty->get_context()); if (tile_type* tile_ty = dynamic_cast(ty)) return tile_type::get_same_shapes(int1_ty, tile_ty); return int1_ty; @@ -173,8 +176,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){ type *arg_ty = arg->get_type(); assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!"); - unsigned arg_bits = arg_ty->get_scalar_bitsize(); - unsigned dst_bits = ty->get_scalar_bitsize(); + unsigned arg_bits = arg_ty->get_integer_bitwidth(); + unsigned dst_bits = ty->get_integer_bitwidth(); op_t op = (arg_bits == dst_bits ? ic::BitCast : (arg_bits > dst_bits ? ic::Trunc : (is_signed ? ic::SExt : ic::ZExt))); @@ -189,7 +192,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, // return_inst return_inst::return_inst(context &ctx, value *ret_val, instruction *next) - : terminator_inst(ctx.get_void_ty(), !!ret_val, "", next){ + : terminator_inst(type::get_void_ty(ctx), !!ret_val, "", next){ if(ret_val) set_operand(0, ret_val); } @@ -202,12 +205,12 @@ return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next // conditional/unconditional branch branch_inst::branch_inst(basic_block *dst, instruction *next) - : terminator_inst(dst->get_context().get_void_ty(), 1, "", next){ + : terminator_inst(type::get_void_ty(dst->get_context()), 1, "", next){ set_operand(0, dst); } branch_inst::branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next) - : terminator_inst(if_dst->get_context().get_void_ty(), 3, "", next){ + : terminator_inst(type::get_void_ty(if_dst->get_context()), 3, "", next){ assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); set_operand(0, if_dst); set_operand(1, else_dst); diff --git a/lib/ir/type.cpp b/lib/ir/type.cpp index e69de29bb..bd49100d1 100644 --- a/lib/ir/type.cpp +++ b/lib/ir/type.cpp @@ -0,0 +1,156 @@ +#include +#include "ir/type.h" +#include "ir/context.h" +#include "ir/context_impl.h" +#include "ir/value.h" + +namespace tdl{ +namespace ir{ + +//===----------------------------------------------------------------------===// +// type class +//===----------------------------------------------------------------------===// + +// attributes +type *type::get_scalar_ty() const { + if(is_tile_ty()) + return get_tile_element_ty(); + return const_cast(this); +} + +unsigned type::get_integer_bitwidth() const +{ return ((integer_type*)(this))->get_bitwidth(); } + +unsigned type::get_fp_mantissa_width() const { + id_t id = get_scalar_ty()->id_; + assert(is_floating_point_ty() && "Not a floating point type!"); + if (id == HalfTyID) return 11; + if (id == FloatTyID) return 24; + if (id == DoubleTyID) return 53; + throw std::runtime_error("unreachable"); +} + +type* type::get_tile_element_ty() const { + assert(is_tile_ty()); + return contained_tys_[0]; +} + +unsigned type::get_pointer_address_space() const { + assert(is_pointer_ty()); + return ((pointer_type*)this)->get_address_space(); +} + +const std::vector &type::get_tile_shapes() const { + assert(is_tile_ty()); + return ((tile_type*)this)->get_shapes(); +} + + +// composite predicates +bool type::is_int_or_tileint_ty() +{ return get_scalar_ty()->is_integer_ty(); } + +bool type::is_integer_ty(unsigned width) const +{ return is_integer_ty() && get_integer_bitwidth()== width; } + + +bool type::is_floating_point_ty() const +{ return is_half_ty() || is_float_ty() || is_double_ty(); } + +bool type::is_sized() const { + // primitive types are sized + if(is_integer_ty() || is_floating_point_ty() || + is_pointer_ty()){ + return true; + } + // tile types are sizes + if(is_tile_ty()) + return get_scalar_ty()->is_sized(); + return false; +} + +// primitive types +type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; } +type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; } +// half +type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; } +type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; } +type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; } +// integer types +integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; } +integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; } +integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } +integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } +integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } +integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } + + + +pointer_type::pointer_type(type *ty, unsigned address_space) + : type(ty->get_context(), PointerTyID), address_space_(address_space){ + contained_tys_.push_back(ty); +} + +bool pointer_type::is_valid_elt_ty(type *ty){ + return !ty->is_void_ty() && !ty->is_label_ty() && + !ty->is_metadata_ty() && !ty->is_token_ty(); +} + +pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ + assert(elt_ty && "Can't get a pointer to type!"); + assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); + // look-up + context_impl *impl = elt_ty->get_context().p_impl.get(); + pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; + if(!entry) + entry = new pointer_type(elt_ty, address_space); + return entry; +} + +//===----------------------------------------------------------------------===// +// composite_type class +//===----------------------------------------------------------------------===// + +type* composite_type::get_type_at_index(value *) const{ + assert(is_tile_ty()); + return get_scalar_ty(); +} + +bool composite_type::index_valid(value *idx) const{ + assert(is_tile_ty()); + return idx->get_type()->is_int_or_tileint_ty(); +} + +//===----------------------------------------------------------------------===// +// tile_type class +//===----------------------------------------------------------------------===// + +tile_type::tile_type(type *ty, const std::vector &shapes) + : composite_type(ty->get_context(), TileTyID), shapes_(shapes) { + contained_tys_.push_back(ty); +} + +bool tile_type::is_valid_elt_ty(type *ty) { + return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty(); +} + +tile_type* tile_type::get(type *elt_ty, const std::vector &shapes) { + assert(elt_ty && "Can't get a tile of type!"); + assert(shapes.size() && "Can't create a tile with empty shapes!"); + assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); + // look-up + context_impl *impl = elt_ty->get_context().p_impl.get(); + tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)]; + if(!entry) + entry = new tile_type(elt_ty, shapes); + return entry; +} + +tile_type* tile_type::get_same_shapes(type *ty, type *ref){ + assert(ref->is_tile_ty()); + return get(ty, ref->get_tile_shapes()); +} + + +} +}