[intermediate representation] added subdefinitions in types submodule

This commit is contained in:
Philippe Tillet
2019-01-03 00:42:37 -05:00
parent 22a83ab526
commit b039498d15
8 changed files with 328 additions and 37 deletions

View File

@@ -1,18 +1,22 @@
#ifndef TDL_INCLUDE_IR_CONTEXT_H
#define TDL_INCLUDE_IR_CONTEXT_H
#include <memory>
#include "ir/type.h"
namespace tdl{
namespace ir{
class type;
class context_impl;
/* Context */
class context {
public:
type *get_void_ty();
type *get_int1_ty();
context();
private:
public:
std::shared_ptr<context_impl> p_impl;
};
}

View File

@@ -36,7 +36,7 @@ private:
class phi_node: public instruction{
private:
phi_node(type *ty, unsigned num_reserved);
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
public:
void set_incoming_value(unsigned i, value *v);
@@ -45,7 +45,7 @@ public:
void add_incoming(value *v, basic_block *block);
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved);
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
private:
unsigned num_reserved_;
@@ -235,6 +235,27 @@ private:
type *res_elt_ty;
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
class retile_inst: public instruction{
};
class reshape_inst: public instruction{
};
class splat_inst: public instruction{
};
class broadcast_inst: public instruction{
};
}
}

View File

@@ -8,62 +8,147 @@ namespace ir{
class context;
class value;
class integer_type;
/* Type */
class type {
public:
enum id_t {
// primitive types
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
FloatTyID, ///< 2: 32-bit floating point type
DoubleTyID, ///< 3: 64-bit floating point type
LabelTyID, ///< 4: Labels
MetadataTyID, ///< 5: Metadata
TokenTyID, ///< 6: Token
// derived types
IntegerTyID, ///< 7: Arbitrary bit width integers
FunctionTyID, ///< 8: Functions
PointerTyID, ///< 9: Pointers
TileTyID, ///< 10: Tile
};
public:
//constructors
type(context &ctx, id_t id) : ctx_(ctx), id_(id) {}
//destructor
virtual ~type(){}
// accessors
context &get_context() const;
context &get_context() const { return ctx_; }
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bit_width() const;
unsigned get_scalar_bitsize() const;
const std::vector<unsigned> &get_tile_shapes() const;
unsigned get_integer_bitwidth() const;
type *get_scalar_ty() const;
const std::vector<unsigned> &get_tile_shapes() const;
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
// type predicates
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
bool is_integer_ty() const;
bool is_integer_ty(unsigned width) const;
bool is_pointer_ty() const;
bool is_float_ty() const;
bool is_double_ty() const;
bool is_floating_point_ty() const;
bool is_sized() const;
bool is_tile_ty() const;
bool is_sized() const ;
// Factory methods
static type* get_void_ty(context &ctx);
static type* get_float_ty(context &ctx);
static type* get_double_ty(context &ctx);
// primitive types
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_half_ty(context &ctx);
static type *get_float_ty(context &ctx);
static type *get_double_ty(context &ctx);
// integer types
static integer_type *get_int1_ty(context &ctx);
static integer_type *get_int8_ty(context &ctx);
static integer_type *get_int16_ty(context &ctx);
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
private:
context &ctx_;
id_t id_;
protected:
std::vector<type*> contained_tys_;
};
class integer_type: public type {
friend class context_impl;
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
};
class composite_type: public type{
protected:
using type::type;
public:
bool index_valid(value *idx) const;
type* get_type_at_index(value *idx) const;
};
class tile_type: public type {
class tile_type: public composite_type {
private:
tile_type(type *ty, const std::vector<unsigned> &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const std::vector<unsigned>& get_shapes() const { return shapes_; }
// factory methods
static tile_type* get(type *ty, const std::vector<unsigned> &shapes);
static tile_type* get_same_shapes(type *ty, type *ref);
private:
std::vector<unsigned> shapes_;
};
class pointer_type: public type {
private:
pointer_type(type *ty, unsigned address_space);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
unsigned get_address_space() const { return address_space_; }
type *get_element_ty() const { return contained_tys_[0]; }
// factory methods
static pointer_type* get(type *ty, unsigned address_space);
type *get_element_ty() const;
private:
unsigned address_space_;
};
class function_type: public type {

View File

@@ -23,7 +23,7 @@ public:
void add_use(use *arg);
// name
void set_name(const std::string &name);
type* get_type() { return ty_; }
type* get_type() const { return ty_; }
private:
type *ty_;

View File

@@ -25,10 +25,10 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
ir::context &ctx = mod->get_context();
switch (spec_) {
case VOID_T: return ir::type::get_void_ty(ctx);
case INT8_T: return ir::integer_type::get(ctx, 8);
case INT16_T: return ir::integer_type::get(ctx, 16);
case INT32_T: return ir::integer_type::get(ctx, 32);
case INT64_T: return ir::integer_type::get(ctx, 64);
case INT8_T: return ir::type::get_int8_ty(ctx);
case INT16_T: return ir::type::get_int16_ty(ctx);
case INT32_T: return ir::type::get_int32_ty(ctx);
case INT64_T: return ir::type::get_int64_ty(ctx);
case FLOAT32_T: return ir::type::get_float_ty(ctx);
case FLOAT64_T: return ir::type::get_double_ty(ctx);
default: throw std::runtime_error("unreachable");
@@ -227,7 +227,7 @@ ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
return builder.create_fp_trunc(src, dst_ty);
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
src_ty->get_integer_bit_width())
src_ty->get_integer_bitwidth())
return builder.create_int_cast(src, dst_ty, dst_signed);
else
@@ -259,8 +259,8 @@ inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
is_int = true;
is_signed = false;
if(left_ty->get_integer_bit_width() != right_ty->get_integer_bit_width()){
ir::value *&to_convert = (left_ty->get_integer_bit_width() > right_ty->get_integer_bit_width())?rhs:lhs;
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
to_convert = llvm_cast(builder, to_convert, dst_ty);
}

View File

@@ -1,7 +1,29 @@
#include "ir/context_impl.h"
#include "ir/context.h"
#include "ir/type.h"
namespace tdl{
namespace ir{
//===----------------------------------------------------------------------===//
// context implementation
//===----------------------------------------------------------------------===//
context_impl::context_impl(context &ctx)
: void_ty(ctx, type::VoidTyID),
label_ty(ctx, type::LabelTyID),
half_ty(ctx, type::HalfTyID),
float_ty(ctx, type::FloatTyID),
double_ty(ctx, type::DoubleTyID),
int1_ty(ctx, 1),
int8_ty(ctx, 8),
int16_ty(ctx, 16),
int32_ty(ctx, 32),
int64_ty(ctx, 64),
int128_ty(ctx, 128)
{
}
}
}

View File

@@ -25,6 +25,9 @@ instruction::instruction(type *ty, unsigned num_ops, const std::string &name, in
// phi_node classes
//===----------------------------------------------------------------------===//
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
: instruction(ty, num_reserved, name, next){ }
// Set incoming value
void phi_node::set_incoming_value(unsigned i, value *v){
assert(v && "PHI node got a null value!");
@@ -51,8 +54,8 @@ void phi_node::add_incoming(value *v, basic_block *block){
}
// Factory methods
phi_node* phi_node::create(type *ty, unsigned num_reserved){
return new phi_node(ty, num_reserved);
phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
return new phi_node(ty, num_reserved, name, next);
}
@@ -103,7 +106,7 @@ cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, cons
}
type* cmp_inst::make_cmp_result_type(type *ty){
type* int1_ty = ty->get_context().get_int1_ty();
type* int1_ty = type::get_int1_ty(ty->get_context());
if (tile_type* tile_ty = dynamic_cast<tile_type*>(ty))
return tile_type::get_same_shapes(int1_ty, tile_ty);
return int1_ty;
@@ -173,8 +176,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
type *arg_ty = arg->get_type();
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
unsigned arg_bits = arg_ty->get_scalar_bitsize();
unsigned dst_bits = ty->get_scalar_bitsize();
unsigned arg_bits = arg_ty->get_integer_bitwidth();
unsigned dst_bits = ty->get_integer_bitwidth();
op_t op = (arg_bits == dst_bits ? ic::BitCast :
(arg_bits > dst_bits ? ic::Trunc :
(is_signed ? ic::SExt : ic::ZExt)));
@@ -189,7 +192,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(ctx.get_void_ty(), !!ret_val, "", next){
: terminator_inst(type::get_void_ty(ctx), !!ret_val, "", next){
if(ret_val)
set_operand(0, ret_val);
}
@@ -202,12 +205,12 @@ return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next
// conditional/unconditional branch
branch_inst::branch_inst(basic_block *dst, instruction *next)
: terminator_inst(dst->get_context().get_void_ty(), 1, "", next){
: terminator_inst(type::get_void_ty(dst->get_context()), 1, "", next){
set_operand(0, dst);
}
branch_inst::branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
: terminator_inst(if_dst->get_context().get_void_ty(), 3, "", next){
: terminator_inst(type::get_void_ty(if_dst->get_context()), 3, "", next){
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
set_operand(0, if_dst);
set_operand(1, else_dst);

View File

@@ -0,0 +1,156 @@
#include <cassert>
#include "ir/type.h"
#include "ir/context.h"
#include "ir/context_impl.h"
#include "ir/value.h"
namespace tdl{
namespace ir{
//===----------------------------------------------------------------------===//
// type class
//===----------------------------------------------------------------------===//
// attributes
type *type::get_scalar_ty() const {
if(is_tile_ty())
return get_tile_element_ty();
return const_cast<type*>(this);
}
unsigned type::get_integer_bitwidth() const
{ return ((integer_type*)(this))->get_bitwidth(); }
unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == HalfTyID) return 11;
if (id == FloatTyID) return 24;
if (id == DoubleTyID) return 53;
throw std::runtime_error("unreachable");
}
type* type::get_tile_element_ty() const {
assert(is_tile_ty());
return contained_tys_[0];
}
unsigned type::get_pointer_address_space() const {
assert(is_pointer_ty());
return ((pointer_type*)this)->get_address_space();
}
const std::vector<unsigned> &type::get_tile_shapes() const {
assert(is_tile_ty());
return ((tile_type*)this)->get_shapes();
}
// composite predicates
bool type::is_int_or_tileint_ty()
{ return get_scalar_ty()->is_integer_ty(); }
bool type::is_integer_ty(unsigned width) const
{ return is_integer_ty() && get_integer_bitwidth()== width; }
bool type::is_floating_point_ty() const
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
bool type::is_sized() const {
// primitive types are sized
if(is_integer_ty() || is_floating_point_ty() ||
is_pointer_ty()){
return true;
}
// tile types are sizes
if(is_tile_ty())
return get_scalar_ty()->is_sized();
return false;
}
// primitive types
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
// half
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
// integer types
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
pointer_type::pointer_type(type *ty, unsigned address_space)
: type(ty->get_context(), PointerTyID), address_space_(address_space){
contained_tys_.push_back(ty);
}
bool pointer_type::is_valid_elt_ty(type *ty){
return !ty->is_void_ty() && !ty->is_label_ty() &&
!ty->is_metadata_ty() && !ty->is_token_ty();
}
pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
assert(elt_ty && "Can't get a pointer to <null> type!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
// look-up
context_impl *impl = elt_ty->get_context().p_impl.get();
pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
if(!entry)
entry = new pointer_type(elt_ty, address_space);
return entry;
}
//===----------------------------------------------------------------------===//
// composite_type class
//===----------------------------------------------------------------------===//
type* composite_type::get_type_at_index(value *) const{
assert(is_tile_ty());
return get_scalar_ty();
}
bool composite_type::index_valid(value *idx) const{
assert(is_tile_ty());
return idx->get_type()->is_int_or_tileint_ty();
}
//===----------------------------------------------------------------------===//
// tile_type class
//===----------------------------------------------------------------------===//
tile_type::tile_type(type *ty, const std::vector<unsigned> &shapes)
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
contained_tys_.push_back(ty);
}
bool tile_type::is_valid_elt_ty(type *ty) {
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
}
tile_type* tile_type::get(type *elt_ty, const std::vector<unsigned> &shapes) {
assert(elt_ty && "Can't get a tile of <null> type!");
assert(shapes.size() && "Can't create a tile with empty shapes!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
// look-up
context_impl *impl = elt_ty->get_context().p_impl.get();
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
if(!entry)
entry = new tile_type(elt_ty, shapes);
return entry;
}
tile_type* tile_type::get_same_shapes(type *ty, type *ref){
assert(ref->is_tile_ty());
return get(ty, ref->get_tile_shapes());
}
}
}