[intermediate representation] added subdefinitions in types submodule
This commit is contained in:
@@ -1,18 +1,22 @@
|
||||
#ifndef TDL_INCLUDE_IR_CONTEXT_H
|
||||
#define TDL_INCLUDE_IR_CONTEXT_H
|
||||
|
||||
#include <memory>
|
||||
#include "ir/type.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
/* Context */
|
||||
class context {
|
||||
public:
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
context();
|
||||
|
||||
private:
|
||||
public:
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -36,7 +36,7 @@ private:
|
||||
|
||||
class phi_node: public instruction{
|
||||
private:
|
||||
phi_node(type *ty, unsigned num_reserved);
|
||||
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
void set_incoming_value(unsigned i, value *v);
|
||||
@@ -45,7 +45,7 @@ public:
|
||||
void add_incoming(value *v, basic_block *block);
|
||||
|
||||
// Factory methods
|
||||
static phi_node* create(type *ty, unsigned num_reserved);
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
@@ -235,6 +235,27 @@ private:
|
||||
type *res_elt_ty;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class retile_inst: public instruction{
|
||||
|
||||
};
|
||||
|
||||
class reshape_inst: public instruction{
|
||||
|
||||
};
|
||||
|
||||
class splat_inst: public instruction{
|
||||
|
||||
};
|
||||
|
||||
class broadcast_inst: public instruction{
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -8,62 +8,147 @@ namespace ir{
|
||||
|
||||
class context;
|
||||
class value;
|
||||
class integer_type;
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
enum id_t {
|
||||
// primitive types
|
||||
VoidTyID = 0, ///< 0: type with no size
|
||||
HalfTyID, ///< 1: 16-bit floating point type
|
||||
FloatTyID, ///< 2: 32-bit floating point type
|
||||
DoubleTyID, ///< 3: 64-bit floating point type
|
||||
LabelTyID, ///< 4: Labels
|
||||
MetadataTyID, ///< 5: Metadata
|
||||
TokenTyID, ///< 6: Token
|
||||
// derived types
|
||||
IntegerTyID, ///< 7: Arbitrary bit width integers
|
||||
FunctionTyID, ///< 8: Functions
|
||||
PointerTyID, ///< 9: Pointers
|
||||
TileTyID, ///< 10: Tile
|
||||
};
|
||||
|
||||
public:
|
||||
//constructors
|
||||
type(context &ctx, id_t id) : ctx_(ctx), id_(id) {}
|
||||
|
||||
//destructor
|
||||
virtual ~type(){}
|
||||
|
||||
// accessors
|
||||
context &get_context() const;
|
||||
context &get_context() const { return ctx_; }
|
||||
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bit_width() const;
|
||||
unsigned get_scalar_bitsize() const;
|
||||
const std::vector<unsigned> &get_tile_shapes() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
type *get_scalar_ty() const;
|
||||
const std::vector<unsigned> &get_tile_shapes() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
|
||||
// type predicates
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
bool is_integer_ty() const;
|
||||
bool is_integer_ty(unsigned width) const;
|
||||
bool is_pointer_ty() const;
|
||||
bool is_float_ty() const;
|
||||
bool is_double_ty() const;
|
||||
bool is_floating_point_ty() const;
|
||||
bool is_sized() const;
|
||||
bool is_tile_ty() const;
|
||||
bool is_sized() const ;
|
||||
|
||||
// Factory methods
|
||||
static type* get_void_ty(context &ctx);
|
||||
static type* get_float_ty(context &ctx);
|
||||
static type* get_double_ty(context &ctx);
|
||||
// primitive types
|
||||
static type *get_void_ty(context &ctx);
|
||||
static type *get_label_ty(context &ctx);
|
||||
// half
|
||||
static type *get_half_ty(context &ctx);
|
||||
static type *get_float_ty(context &ctx);
|
||||
static type *get_double_ty(context &ctx);
|
||||
// integer types
|
||||
static integer_type *get_int1_ty(context &ctx);
|
||||
static integer_type *get_int8_ty(context &ctx);
|
||||
static integer_type *get_int16_ty(context &ctx);
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
|
||||
protected:
|
||||
std::vector<type*> contained_tys_;
|
||||
};
|
||||
|
||||
class integer_type: public type {
|
||||
friend class context_impl;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
protected:
|
||||
using type::type;
|
||||
|
||||
public:
|
||||
bool index_valid(value *idx) const;
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class tile_type: public type {
|
||||
class tile_type: public composite_type {
|
||||
private:
|
||||
tile_type(type *ty, const std::vector<unsigned> &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const std::vector<unsigned>& get_shapes() const { return shapes_; }
|
||||
|
||||
// factory methods
|
||||
static tile_type* get(type *ty, const std::vector<unsigned> &shapes);
|
||||
static tile_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
private:
|
||||
std::vector<unsigned> shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
private:
|
||||
pointer_type(type *ty, unsigned address_space);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_address_space() const { return address_space_; }
|
||||
type *get_element_ty() const { return contained_tys_[0]; }
|
||||
|
||||
// factory methods
|
||||
static pointer_type* get(type *ty, unsigned address_space);
|
||||
type *get_element_ty() const;
|
||||
|
||||
private:
|
||||
unsigned address_space_;
|
||||
};
|
||||
|
||||
class function_type: public type {
|
||||
|
@@ -23,7 +23,7 @@ public:
|
||||
void add_use(use *arg);
|
||||
// name
|
||||
void set_name(const std::string &name);
|
||||
type* get_type() { return ty_; }
|
||||
type* get_type() const { return ty_; }
|
||||
|
||||
private:
|
||||
type *ty_;
|
||||
|
@@ -25,10 +25,10 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
|
||||
ir::context &ctx = mod->get_context();
|
||||
switch (spec_) {
|
||||
case VOID_T: return ir::type::get_void_ty(ctx);
|
||||
case INT8_T: return ir::integer_type::get(ctx, 8);
|
||||
case INT16_T: return ir::integer_type::get(ctx, 16);
|
||||
case INT32_T: return ir::integer_type::get(ctx, 32);
|
||||
case INT64_T: return ir::integer_type::get(ctx, 64);
|
||||
case INT8_T: return ir::type::get_int8_ty(ctx);
|
||||
case INT16_T: return ir::type::get_int16_ty(ctx);
|
||||
case INT32_T: return ir::type::get_int32_ty(ctx);
|
||||
case INT64_T: return ir::type::get_int64_ty(ctx);
|
||||
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
||||
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
@@ -227,7 +227,7 @@ ir::value *llvm_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
||||
return builder.create_fp_trunc(src, dst_ty);
|
||||
|
||||
else if(src_ty->is_integer_ty() && dst_ty->is_integer_ty() &&
|
||||
src_ty->get_integer_bit_width())
|
||||
src_ty->get_integer_bitwidth())
|
||||
return builder.create_int_cast(src, dst_ty, dst_signed);
|
||||
|
||||
else
|
||||
@@ -259,8 +259,8 @@ inline void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs
|
||||
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
||||
is_int = true;
|
||||
is_signed = false;
|
||||
if(left_ty->get_integer_bit_width() != right_ty->get_integer_bit_width()){
|
||||
ir::value *&to_convert = (left_ty->get_integer_bit_width() > right_ty->get_integer_bit_width())?rhs:lhs;
|
||||
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
|
||||
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
|
||||
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
|
||||
to_convert = llvm_cast(builder, to_convert, dst_ty);
|
||||
}
|
||||
|
@@ -1,7 +1,29 @@
|
||||
#include "ir/context_impl.h"
|
||||
#include "ir/context.h"
|
||||
#include "ir/type.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
context_impl::context_impl(context &ctx)
|
||||
: void_ty(ctx, type::VoidTyID),
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
half_ty(ctx, type::HalfTyID),
|
||||
float_ty(ctx, type::FloatTyID),
|
||||
double_ty(ctx, type::DoubleTyID),
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
int32_ty(ctx, 32),
|
||||
int64_ty(ctx, 64),
|
||||
int128_ty(ctx, 128)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -25,6 +25,9 @@ instruction::instruction(type *ty, unsigned num_ops, const std::string &name, in
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
|
||||
: instruction(ty, num_reserved, name, next){ }
|
||||
|
||||
// Set incoming value
|
||||
void phi_node::set_incoming_value(unsigned i, value *v){
|
||||
assert(v && "PHI node got a null value!");
|
||||
@@ -51,8 +54,8 @@ void phi_node::add_incoming(value *v, basic_block *block){
|
||||
}
|
||||
|
||||
// Factory methods
|
||||
phi_node* phi_node::create(type *ty, unsigned num_reserved){
|
||||
return new phi_node(ty, num_reserved);
|
||||
phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &name, instruction *next){
|
||||
return new phi_node(ty, num_reserved, name, next);
|
||||
}
|
||||
|
||||
|
||||
@@ -103,7 +106,7 @@ cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, cons
|
||||
}
|
||||
|
||||
type* cmp_inst::make_cmp_result_type(type *ty){
|
||||
type* int1_ty = ty->get_context().get_int1_ty();
|
||||
type* int1_ty = type::get_int1_ty(ty->get_context());
|
||||
if (tile_type* tile_ty = dynamic_cast<tile_type*>(ty))
|
||||
return tile_type::get_same_shapes(int1_ty, tile_ty);
|
||||
return int1_ty;
|
||||
@@ -173,8 +176,8 @@ cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &n
|
||||
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
|
||||
type *arg_ty = arg->get_type();
|
||||
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
|
||||
unsigned arg_bits = arg_ty->get_scalar_bitsize();
|
||||
unsigned dst_bits = ty->get_scalar_bitsize();
|
||||
unsigned arg_bits = arg_ty->get_integer_bitwidth();
|
||||
unsigned dst_bits = ty->get_integer_bitwidth();
|
||||
op_t op = (arg_bits == dst_bits ? ic::BitCast :
|
||||
(arg_bits > dst_bits ? ic::Trunc :
|
||||
(is_signed ? ic::SExt : ic::ZExt)));
|
||||
@@ -189,7 +192,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
// return_inst
|
||||
|
||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||
: terminator_inst(ctx.get_void_ty(), !!ret_val, "", next){
|
||||
: terminator_inst(type::get_void_ty(ctx), !!ret_val, "", next){
|
||||
if(ret_val)
|
||||
set_operand(0, ret_val);
|
||||
}
|
||||
@@ -202,12 +205,12 @@ return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next
|
||||
// conditional/unconditional branch
|
||||
|
||||
branch_inst::branch_inst(basic_block *dst, instruction *next)
|
||||
: terminator_inst(dst->get_context().get_void_ty(), 1, "", next){
|
||||
: terminator_inst(type::get_void_ty(dst->get_context()), 1, "", next){
|
||||
set_operand(0, dst);
|
||||
}
|
||||
|
||||
branch_inst::branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
|
||||
: terminator_inst(if_dst->get_context().get_void_ty(), 3, "", next){
|
||||
: terminator_inst(type::get_void_ty(if_dst->get_context()), 3, "", next){
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
set_operand(0, if_dst);
|
||||
set_operand(1, else_dst);
|
||||
|
156
lib/ir/type.cpp
156
lib/ir/type.cpp
@@ -0,0 +1,156 @@
|
||||
#include <cassert>
|
||||
#include "ir/type.h"
|
||||
#include "ir/context.h"
|
||||
#include "ir/context_impl.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// attributes
|
||||
type *type::get_scalar_ty() const {
|
||||
if(is_tile_ty())
|
||||
return get_tile_element_ty();
|
||||
return const_cast<type*>(this);
|
||||
}
|
||||
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == HalfTyID) return 11;
|
||||
if (id == FloatTyID) return 24;
|
||||
if (id == DoubleTyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
type* type::get_tile_element_ty() const {
|
||||
assert(is_tile_ty());
|
||||
return contained_tys_[0];
|
||||
}
|
||||
|
||||
unsigned type::get_pointer_address_space() const {
|
||||
assert(is_pointer_ty());
|
||||
return ((pointer_type*)this)->get_address_space();
|
||||
}
|
||||
|
||||
const std::vector<unsigned> &type::get_tile_shapes() const {
|
||||
assert(is_tile_ty());
|
||||
return ((tile_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
|
||||
// composite predicates
|
||||
bool type::is_int_or_tileint_ty()
|
||||
{ return get_scalar_ty()->is_integer_ty(); }
|
||||
|
||||
bool type::is_integer_ty(unsigned width) const
|
||||
{ return is_integer_ty() && get_integer_bitwidth()== width; }
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
if(is_integer_ty() || is_floating_point_ty() ||
|
||||
is_pointer_ty()){
|
||||
return true;
|
||||
}
|
||||
// tile types are sizes
|
||||
if(is_tile_ty())
|
||||
return get_scalar_ty()->is_sized();
|
||||
return false;
|
||||
}
|
||||
|
||||
// primitive types
|
||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// half
|
||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||
// integer types
|
||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||
integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
|
||||
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
|
||||
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
|
||||
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
|
||||
|
||||
|
||||
|
||||
pointer_type::pointer_type(type *ty, unsigned address_space)
|
||||
: type(ty->get_context(), PointerTyID), address_space_(address_space){
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool pointer_type::is_valid_elt_ty(type *ty){
|
||||
return !ty->is_void_ty() && !ty->is_label_ty() &&
|
||||
!ty->is_metadata_ty() && !ty->is_token_ty();
|
||||
}
|
||||
|
||||
pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
|
||||
assert(elt_ty && "Can't get a pointer to <null> type!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
|
||||
if(!entry)
|
||||
entry = new pointer_type(elt_ty, address_space);
|
||||
return entry;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// composite_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
type* composite_type::get_type_at_index(value *) const{
|
||||
assert(is_tile_ty());
|
||||
return get_scalar_ty();
|
||||
}
|
||||
|
||||
bool composite_type::index_valid(value *idx) const{
|
||||
assert(is_tile_ty());
|
||||
return idx->get_type()->is_int_or_tileint_ty();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
tile_type::tile_type(type *ty, const std::vector<unsigned> &shapes)
|
||||
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool tile_type::is_valid_elt_ty(type *ty) {
|
||||
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
|
||||
}
|
||||
|
||||
tile_type* tile_type::get(type *elt_ty, const std::vector<unsigned> &shapes) {
|
||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
|
||||
if(!entry)
|
||||
entry = new tile_type(elt_ty, shapes);
|
||||
return entry;
|
||||
}
|
||||
|
||||
tile_type* tile_type::get_same_shapes(type *ty, type *ref){
|
||||
assert(ref->is_tile_ty());
|
||||
return get(ty, ref->get_tile_shapes());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user