[intermediate representation] added some instruction definitions

This commit is contained in:
Philippe Tillet
2019-01-02 19:29:59 -05:00
parent 0378b9eb43
commit 22a83ab526
6 changed files with 314 additions and 36 deletions

View File

@@ -24,8 +24,9 @@ public:
public:
// parent
// accessors
function* get_parent() { return parent_; }
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
iterator get_first_non_phi();
@@ -51,10 +52,8 @@ public:
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
// get predecessors
const std::vector<basic_block*>& get_predecessors() const;
// add predecessor
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
void add_predecessor(basic_block* pred);
// factory functions

View File

@@ -4,8 +4,15 @@
namespace tdl{
namespace ir{
class type;
/* Context */
class context {
public:
type *get_void_ty();
type *get_int1_ty();
private:
};
}

View File

@@ -9,6 +9,7 @@ namespace tdl{
namespace ir{
class basic_block;
class context;
//===----------------------------------------------------------------------===//
// instruction classes
@@ -17,7 +18,7 @@ class basic_block;
class instruction: public user{
protected:
// constructors
instruction(type *ty, unsigned num_ops, instruction *next = nullptr);
instruction(type *ty, unsigned num_ops, const std::string &name = "", instruction *next = nullptr);
public:
@@ -38,13 +39,17 @@ private:
phi_node(type *ty, unsigned num_reserved);
public:
void add_incoming(value *x, basic_block *bb);
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
void add_incoming(value *v, basic_block *block);
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved);
private:
unsigned num_reserved_;
std::vector<basic_block*> blocks_;
};
//===----------------------------------------------------------------------===//
@@ -84,11 +89,10 @@ public:
typedef llvm::CmpInst::Predicate pred_t;
using pcmp = llvm::CmpInst;
private:
type* make_cmp_result_type(type *ty);
protected:
cmp_inst(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
cmp_inst(type *ty, pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
static type* make_cmp_result_type(type *ty);
static bool is_fp_predicate(pred_t pred);
static bool is_int_predicate(pred_t pred);
@@ -116,17 +120,29 @@ public:
const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// unary_inst classes
//===----------------------------------------------------------------------===//
class unary_inst: public instruction {
protected:
unary_inst(type *Ty, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
class cast_inst: public instruction{
class cast_inst: public unary_inst{
using unary_inst::unary_inst;
using ic = llvm::Instruction::CastOps;
public:
typedef llvm::CastInst::CastOps op_t;
protected:
// Constructors
cast_inst(op_t op, value *arg, type *ty, const std::string &name, instruction *next);
private:
bool is_valid(op_t op, value *arg, type *ty);
public:
// Factory methods
@@ -135,33 +151,67 @@ public:
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
const std::string &name = "", instruction *next = nullptr);
private:
op_t op_;
};
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name) \
class name : public cast_inst{ \
friend class cast_inst; \
using cast_inst::cast_inst; \
};
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst)
//===----------------------------------------------------------------------===//
// terminator_inst classes
//===----------------------------------------------------------------------===//
class terminator_inst: public instruction{
public:
using instruction::instruction;
};
class return_inst: public instruction{
// return instruction
class return_inst: public terminator_inst{
return_inst(context &ctx, value *ret_val, instruction *next);
public:
// accessors
value *get_return_value()
{ return get_num_operands() ? get_operand(0) : nullptr; }
unsigned get_num_successors() const { return 0; }
// factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// branch_inst classes
//===----------------------------------------------------------------------===//
// conditional/unconditional branch instruction
class branch_inst: public terminator_inst{
branch_inst(basic_block *dst, instruction *next);
branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
class branch_inst: public instruction{
public:
// factory methods
static branch_inst* create(basic_block *dest,
const std::string &name = "", instruction *next = nullptr);
instruction *next = nullptr);
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
const std::string &name = "", instruction *next = nullptr);
instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
@@ -169,9 +219,20 @@ public:
//===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction{
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
private:
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
public:
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr);
private:
type *source_elt_ty;
type *res_elt_ty;
};

View File

@@ -7,24 +7,40 @@ namespace tdl{
namespace ir{
class context;
class value;
/* Type */
class type {
public:
bool is_integer_ty() const;
bool is_pointer_ty() const;
bool is_float_ty() const;
bool is_double_ty() const;
bool is_floating_point_ty() const;
virtual ~type(){}
// accessors
context &get_context() const;
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bit_width() const;
unsigned get_scalar_bitsize() const;
const std::vector<unsigned> &get_tile_shapes() const;
type *get_scalar_ty() const;
unsigned get_pointer_address_space() const;
// type predicates
bool is_int_or_tileint_ty();
bool is_integer_ty() const;
bool is_integer_ty(unsigned width) const;
bool is_pointer_ty() const;
bool is_float_ty() const;
bool is_double_ty() const;
bool is_floating_point_ty() const;
bool is_sized() const;
bool is_tile_ty() const;
// Factory methods
static type* get_void_ty(context &ctx);
static type* get_float_ty(context &ctx);
static type* get_double_ty(context &ctx);
};
class integer_type: public type {
@@ -32,14 +48,22 @@ public:
static integer_type* get(context &ctx, unsigned width);
};
class composite_type: public type{
public:
bool index_valid(value *idx) const;
type* get_type_at_index(value *idx) const;
};
class tile_type: public type {
public:
static tile_type* get(type *ty, const std::vector<unsigned> &shapes);
static tile_type* get_same_shapes(type *ty, type *ref);
};
class pointer_type: public type {
public:
static pointer_type* get(type *ty, unsigned address_space);
type *get_element_ty() const;
};
class function_type: public type {

View File

@@ -55,6 +55,9 @@ private:
//===----------------------------------------------------------------------===//
class user: public value{
protected:
void resize_ops(unsigned n) { ops_.resize(n); }
public:
// Constructor
user(type *ty, unsigned num_ops, const std::string &name = "")

View File

@@ -1,6 +1,8 @@
#include "ir/context.h"
#include "ir/basic_block.h"
#include "ir/instructions.h"
#include "ir/constant.h"
#include "ir/type.h"
namespace tdl{
namespace ir{
@@ -9,8 +11,8 @@ namespace ir{
// instruction classes
//===----------------------------------------------------------------------===//
instruction::instruction(type *ty, unsigned num_ops, instruction *next)
: user(ty, num_ops) {
instruction::instruction(type *ty, unsigned num_ops, const std::string &name, instruction *next)
: user(ty, num_ops, name) {
if(next){
basic_block *block = next->get_parent();
assert(block && "Next instruction is not in a basic block!");
@@ -23,9 +25,29 @@ instruction::instruction(type *ty, unsigned num_ops, instruction *next)
// phi_node classes
//===----------------------------------------------------------------------===//
// Add incoming
void phi_node::add_incoming(value *x, basic_block *bb){
// Set incoming value
void phi_node::set_incoming_value(unsigned i, value *v){
assert(v && "PHI node got a null value!");
assert(get_type() == v->get_type() &&
"All operands to PHI node must be the same type as the PHI node!");
set_operand(i, v);
}
// Set incoming block
void phi_node::set_incoming_block(unsigned i, basic_block *block){
assert(block && "PHI node got a null basic block!");
blocks_[i] = block;
}
// Add incoming
void phi_node::add_incoming(value *v, basic_block *block){
if(get_num_operands()==num_reserved_){
num_reserved_++;
resize_ops(num_reserved_);
blocks_.resize(num_reserved_);
}
set_incoming_value(get_num_operands() - 1, v);
set_incoming_block(get_num_operands() - 1, block);
}
// Factory methods
@@ -39,7 +61,7 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved){
//===----------------------------------------------------------------------===//
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, next), op_(op){
: instruction(ty, 2, name, next), op_(op){
set_operand(0, lhs);
set_operand(1, rhs);
}
@@ -72,6 +94,24 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
// cmp_inst classes
//===----------------------------------------------------------------------===//
// cmp_inst
cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, 2, name, next), pred_(pred) {
set_operand(0, lhs);
set_operand(1, rhs);
}
type* cmp_inst::make_cmp_result_type(type *ty){
type* int1_ty = ty->get_context().get_int1_ty();
if (tile_type* tile_ty = dynamic_cast<tile_type*>(ty))
return tile_type::get_same_shapes(int1_ty, tile_ty);
return int1_ty;
}
bool cmp_inst::is_fp_predicate(pred_t pred) {
return pred >= pcmp::FIRST_FCMP_PREDICATE && pred <= pcmp::LAST_FCMP_PREDICATE;
}
@@ -84,15 +124,159 @@ bool cmp_inst::is_int_predicate(pred_t pred) {
icmp_inst* icmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_int_predicate(pred));
return new icmp_inst(pred, lhs, rhs, name, next);
type *res_ty = make_cmp_result_type(lhs->get_type());
return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
}
// fcmp_inst
fcmp_inst* fcmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_fp_predicate(pred));
return new fcmp_inst(pred, lhs, rhs, name, next);
type *res_ty = make_cmp_result_type(lhs->get_type());
return new fcmp_inst(res_ty, pred, lhs, rhs, name, next);
}
//===----------------------------------------------------------------------===//
// unary_inst classes
//===----------------------------------------------------------------------===//
unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next)
: instruction(ty, 1, name, next) {
set_operand(0, v);
}
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
cast_inst *cast_inst::create(op_t op, value *arg, type *ty, const std::string &name, instruction *next){
assert(is_valid(op, arg, ty) && "Invalid cast!");
// Construct and return the appropriate CastInst subclass
switch (op) {
case ic::Trunc: return new trunc_inst (ty, arg, name, next);
case ic::ZExt: return new z_ext_inst (ty, arg, name, next);
case ic::SExt: return new s_ext_inst (ty, arg, name, next);
case ic::FPTrunc: return new fp_trunc_inst (ty, arg, name, next);
case ic::FPExt: return new fp_ext_inst (ty, arg, name, next);
case ic::UIToFP: return new ui_to_fp_inst (ty, arg, name, next);
case ic::SIToFP: return new si_to_fp_inst (ty, arg, name, next);
case ic::FPToUI: return new fp_to_ui_inst (ty, arg, name, next);
case ic::FPToSI: return new fp_to_si_inst (ty, arg, name, next);
case ic::PtrToInt: return new ptr_to_int_inst (ty, arg, name, next);
case ic::IntToPtr: return new int_to_ptr_inst (ty, arg, name, next);
case ic::BitCast: return new bit_cast_inst (ty, arg, name, next);
case ic::AddrSpaceCast: return new addr_space_cast_inst (ty, arg, name, next);
default: throw std::runtime_error("unreachable");
}
}
cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, const std::string &name, instruction *next){
type *arg_ty = arg->get_type();
assert(arg_ty->is_int_or_tileint_ty() && ty->is_int_or_tileint_ty() && "Invalid integer cast!");
unsigned arg_bits = arg_ty->get_scalar_bitsize();
unsigned dst_bits = ty->get_scalar_bitsize();
op_t op = (arg_bits == dst_bits ? ic::BitCast :
(arg_bits > dst_bits ? ic::Trunc :
(is_signed ? ic::SExt : ic::ZExt)));
return create(op, arg, ty, name, next);
}
//===----------------------------------------------------------------------===//
// terminator_inst classes
//===----------------------------------------------------------------------===//
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(ctx.get_void_ty(), !!ret_val, "", next){
if(ret_val)
set_operand(0, ret_val);
}
return_inst *return_inst::create(context &ctx, value *ret_val, instruction *next){
return new return_inst(ctx, ret_val, next);
}
// conditional/unconditional branch
branch_inst::branch_inst(basic_block *dst, instruction *next)
: terminator_inst(dst->get_context().get_void_ty(), 1, "", next){
set_operand(0, dst);
}
branch_inst::branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
: terminator_inst(if_dst->get_context().get_void_ty(), 3, "", next){
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
set_operand(0, if_dst);
set_operand(1, else_dst);
set_operand(2, cond);
}
branch_inst* branch_inst::create(basic_block *dst, instruction *next) {
assert(dst && "Branch destination may not be null!");
return new branch_inst(dst, next);
}
branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *else_dst, instruction *next) {
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
return new branch_inst(if_dst, else_dst, cond, next);
}
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
: instruction(get_return_type(pointee_ty, ptr, idx), idx.size(), name, next),
source_elt_ty(pointee_ty),
res_elt_ty(get_indexed_type(pointee_ty, idx)){
type *expected_ty = ((pointer_type*)(get_type()->get_scalar_ty()))->get_element_ty();
assert(res_elt_ty == expected_ty);
set_operand(0, ptr);
for(size_t i = 0; i < idx.size(); i++)
set_operand(1 + i, idx[i]);
}
type *getelementptr_inst::get_return_type(type *elt_ty, value *ptr, const std::vector<value *> &idx_list) {
// result pointer type
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_pointer_address_space());
// Tile GEP
if(ptr->get_type()->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, ptr->get_type());
for(value *idx : idx_list)
if (idx->get_type()->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, idx->get_type());
// Scalar GEP
return ptr_ty;
}
type *getelementptr_inst::get_indexed_type_impl(type *ty, const std::vector<value *> &idx_list) {
if(idx_list.empty())
return ty;
if(!ty->is_sized())
return nullptr;
unsigned cur_idx = 1;
for(; cur_idx != idx_list.size(); cur_idx++){
composite_type *cty = dynamic_cast<composite_type*>(ty);
if(!cty || cty->is_pointer_ty())
break;
value *idx = idx_list[cur_idx];
if(!cty->index_valid(idx))
break;
ty = cty->get_type_at_index(idx);
}
return (cur_idx == idx_list.size())? ty : nullptr;
}
type *getelementptr_inst::get_indexed_type(type *ty, const std::vector<value *> &idx_list) {
type *result = get_indexed_type_impl(ty, idx_list);
assert(result && "invalid GEP type!");
return result;
}
}
}