[syntax tree] fixed bugs in control flow lowering

This commit is contained in:
Philippe Tillet
2019-02-17 21:35:03 -05:00
parent cf1a583dbf
commit f3094a512b
8 changed files with 201 additions and 35 deletions

View File

@@ -5,6 +5,7 @@
#include "ast/ast.h"
#include "ir/context.h"
#include "ir/module.h"
#include "ir/print.h"
#include "codegen/selection.h"
#include "codegen/tune.h"
#include "codegen/shared_copy.h"
@@ -54,13 +55,15 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
int1 checkc1[16] = ryc < N;\
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
for(k = K; k > 0; k = k - 8){\
int1 sanitya[16, 8] = (k > 16);\
int1 sanityb[16, 8] = (k > 16);\
int1 checka[16, 8] = (k > 8);\
int1 checkb[16, 8] = (k > 8);\
C = dot(a, b, C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
@sanitya a = *pa;\
@sanityb b = *pb;\
@checka a = *pa;\
@checkb b = *pb;\
if(k <= 8){\
}\
}\
@checkc *pc = C;\
}\
@@ -167,6 +170,8 @@ int main() {
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("test", llvm_context);
tdl::ir::print(module, std::cout);
// create passes
tdl::codegen::place_shared_copy shared;
tdl::codegen::buffer_info_pass buffer_info;

View File

@@ -287,17 +287,14 @@ public:
class assignment_expression: public expression{
public:
assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue)
: lhs_((named_expression*)lvalue), op_(op), rhs_((expression*)rvalue) { }
const expression *lhs() const { return lhs_; }
const expression *rhs() const { return rhs_; }
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
ir::value* codegen(ir::module *mod) const;
public:
ASSIGN_OP_T op_;
const expression *lhs_;
const expression *rhs_;
const expression *lvalue_;
const expression *rvalue_;
};

View File

@@ -22,6 +22,8 @@ public:
value *else_value;
};
virtual std::string repr_impl() const = 0;
protected:
// constructors
instruction(type *ty, unsigned num_ops, const std::string &name = "", instruction *next = nullptr);
@@ -37,6 +39,8 @@ public:
const mask_info_t get_mask() const { return mask_; }
// helpers
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
private:
basic_block *parent_;
@@ -51,6 +55,7 @@ private:
class phi_node: public instruction{
private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; }
public:
void set_incoming_value(unsigned i, value *v);
@@ -60,6 +65,9 @@ public:
unsigned get_num_incoming() { return get_num_operands(); }
void add_incoming(value *v, basic_block *block);
// Type
void set_type(type *ty) { ty_ = ty; }
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
@@ -75,6 +83,10 @@ private:
class binary_operator: public instruction{
public:
typedef llvm::BinaryOperator::BinaryOps op_t;
using llop = llvm::BinaryOperator::BinaryOps;
private:
std::string repr_impl() const;
protected:
// Constructors
@@ -116,7 +128,10 @@ public:
class cmp_inst: public instruction{
public:
typedef llvm::CmpInst::Predicate pred_t;
using pcmp = llvm::CmpInst;
using llop = llvm::CmpInst;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
@@ -164,6 +179,9 @@ protected:
class cast_inst: public unary_inst{
using ic = llvm::Instruction::CastOps;
private:
std::string repr_impl() const;
public:
typedef llvm::CastInst::CastOps op_t;
@@ -219,6 +237,8 @@ class terminator_inst: public instruction{
// return instruction
class return_inst: public terminator_inst{
private:
std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next);
public:
@@ -234,6 +254,9 @@ public:
// base branch instruction
class branch_inst: public terminator_inst{
private:
std::string repr_impl() const { return "br"; }
protected:
using terminator_inst::terminator_inst;
@@ -246,8 +269,9 @@ public:
// conditional branch
class cond_branch_inst: public branch_inst {
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
private:
friend class branch_inst;
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
@@ -257,6 +281,7 @@ public:
// unconditional branch
class uncond_branch_inst: public branch_inst {
private:
friend class branch_inst;
uncond_branch_inst(basic_block *dst, instruction *next);
@@ -269,6 +294,7 @@ public:
class getelementptr_inst: public instruction{
private:
std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
private:
@@ -297,6 +323,7 @@ private:
class load_inst: public unary_inst{
private:
std::string repr_impl() const { return "load"; }
load_inst(value *ptr, const std::string &name, instruction *next);
private:
@@ -312,6 +339,7 @@ public:
class store_inst: public instruction{
private:
std::string repr_impl() const { return "store"; }
store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
@@ -330,36 +358,43 @@ public:
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, const std::vector<unsigned> &shapes, const std::string &name, instruction *next);
retile_inst(value *arg, const std::vector<unsigned> &shape_suffix, const std::string &name, instruction *next);
static std::string shape_suffix(ir::type* ty);
};
// reshape
class reshape_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
public:
static instruction* create(value *arg, const std::vector<unsigned> &shapes,
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
};
// splat
class splat_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
public:
static instruction* create(value *arg, const std::vector<unsigned> &shapes,
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
};
// broadcast
class broadcast_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
public:
static instruction* create(value *arg, const std::vector<unsigned> &shapes,
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
};
@@ -374,7 +409,9 @@ protected:
};
class get_global_range_inst: public builtin_inst {
private:
get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, unsigned size,
@@ -387,7 +424,9 @@ private:
};
class matmul_inst: public builtin_inst {
private:
matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
public:
static instruction* create(value *A, value *B, value *C,
@@ -401,7 +440,9 @@ public:
//===----------------------------------------------------------------------===//
class copy_to_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_to_shared"; }
public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
@@ -411,6 +452,7 @@ public:
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; }
public:
static barrier_inst* create(context &ctx, const std::string &name = "",
@@ -418,7 +460,9 @@ public:
};
class vectorize_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "vectorize"; }
public:
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);

View File

@@ -42,9 +42,13 @@ public:
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_type(const std::string& name, basic_block* block, type* x);
void set_type(const std::string& name, type* x);
// Getters
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
type *get_type(const std::string& name, basic_block* block);
type *get_type(const std::string& name);
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions
@@ -58,6 +62,7 @@ private:
context &context_;
builder builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;

View File

@@ -33,10 +33,10 @@ public:
type* get_type() const { return ty_; }
private:
type *ty_;
std::string name_;
protected:
type *ty_;
std::set<user*> users_;
};

View File

@@ -212,6 +212,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{
if(id_i){
args[i]->set_name(id_i->name());
mod->set_value(id_i->name(), nullptr, args[i]);
mod->set_type(id_i->name(), nullptr, args[i]->get_type());
}
}
}
@@ -262,7 +263,8 @@ ir::value* expression_statement::codegen(ir::module *mod) const{
ir::value* iteration_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::context &ctx = mod->get_context();
ir::function *fn = builder.get_insert_block()->get_parent();
ir::basic_block *current_bb = builder.get_insert_block();
ir::function *fn = current_bb->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
init_->codegen(mod);
builder.create_br(loop_bb);
@@ -270,8 +272,10 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
statements_->codegen(mod);
exec_->codegen(mod);
ir::value *cond = stop_->codegen(mod);
ir::basic_block *stop_bb = builder.get_insert_block();
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
builder.create_cond_br(cond, loop_bb, next_bb);
mod->seal_block(stop_bb);
mod->seal_block(loop_bb);
mod->seal_block(builder.get_insert_block());
mod->seal_block(next_bb);
@@ -296,8 +300,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
// Then
builder.set_insert_point(then_bb);
then_value_->codegen(mod);
if(else_value_)
builder.create_br(endif_bb);
builder.create_br(endif_bb);
mod->seal_block(then_bb);
// Else
if(else_value_){
@@ -338,6 +341,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
}
value->set_name(name);
mod->set_value(name, value);
mod->set_type(name, ty);
return value;
}
@@ -527,16 +531,16 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
/* Assignment expression */
ir::value *assignment_expression::codegen(ir::module *mod) const{
ir::value *rhs = rhs_->codegen(mod);
ir::value *rvalue = rvalue_->codegen(mod);
if(auto *x = dynamic_cast<const named_expression*>(lvalue_))
mod->set_value(x->id()->name(), rhs);
mod->set_value(x->id()->name(), rvalue);
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
assert(x->get_op()==DEREF);
assert(x->lvalue());
ir::value *ptr = x->lvalue()->codegen(mod);
rhs = mod->get_builder().create_store(ptr, rhs);
rvalue = mod->get_builder().create_store(ptr, rvalue);
}
return rhs;
return rvalue;
}
/* Type name */

View File

@@ -73,6 +73,30 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
// binary_operator classes
//===----------------------------------------------------------------------===//
// Textual mnemonic for this binary operator, selected from op_.
// Throws std::runtime_error when op_ is not one of the opcodes listed below.
std::string binary_operator::repr_impl() const {
  const char *mnemonic = nullptr;
  switch(op_) {
    // additive
    case llop::Add  : mnemonic = "add";  break;
    case llop::FAdd : mnemonic = "fadd"; break;
    case llop::Sub  : mnemonic = "sub";  break;
    case llop::FSub : mnemonic = "fsub"; break;
    // multiplicative
    case llop::Mul  : mnemonic = "mul";  break;
    case llop::FMul : mnemonic = "fmul"; break;
    case llop::UDiv : mnemonic = "udiv"; break;
    case llop::SDiv : mnemonic = "sdiv"; break;
    case llop::FDiv : mnemonic = "fdiv"; break;
    case llop::URem : mnemonic = "urem"; break;
    case llop::SRem : mnemonic = "srem"; break;
    case llop::FRem : mnemonic = "frem"; break;
    // shifts
    case llop::Shl  : mnemonic = "shl";  break;
    case llop::LShr : mnemonic = "lshr"; break;
    case llop::AShr : mnemonic = "ashr"; break;
    // bitwise
    case llop::And  : mnemonic = "and";  break;
    case llop::Or   : mnemonic = "or";   break;
    case llop::Xor  : mnemonic = "xor";  break;
    default: break;
  }
  // No case matched: the opcode has no mnemonic in the table above.
  if(mnemonic == nullptr)
    throw std::runtime_error("unknown binary operator");
  return mnemonic;
}
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, name, next), op_(op){
set_operand(0, lhs);
@@ -108,6 +132,38 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
//===----------------------------------------------------------------------===//
// cmp_inst
// Textual mnemonic for this comparison, selected from pred_.
// The always-false / always-true floating-point predicates are rendered as
// plain "false" / "true"; every other predicate is "<fcmp|icmp>_<cond>".
// Throws std::runtime_error when pred_ is not one of the predicates listed.
std::string cmp_inst::repr_impl() const {
  const char *mnemonic = nullptr;
  switch (pred_) {
    // floating-point predicates
    case llop::FCMP_FALSE : mnemonic = "false";    break;
    case llop::FCMP_OEQ   : mnemonic = "fcmp_oeq"; break;
    case llop::FCMP_OGT   : mnemonic = "fcmp_ogt"; break;
    case llop::FCMP_OGE   : mnemonic = "fcmp_oge"; break;
    case llop::FCMP_OLT   : mnemonic = "fcmp_olt"; break;
    case llop::FCMP_OLE   : mnemonic = "fcmp_ole"; break;
    case llop::FCMP_ONE   : mnemonic = "fcmp_one"; break;
    case llop::FCMP_ORD   : mnemonic = "fcmp_ord"; break;
    case llop::FCMP_UNO   : mnemonic = "fcmp_uno"; break;
    case llop::FCMP_UEQ   : mnemonic = "fcmp_ueq"; break;
    case llop::FCMP_UGT   : mnemonic = "fcmp_ugt"; break;
    case llop::FCMP_UGE   : mnemonic = "fcmp_uge"; break;
    case llop::FCMP_ULT   : mnemonic = "fcmp_ult"; break;
    case llop::FCMP_ULE   : mnemonic = "fcmp_ule"; break;
    case llop::FCMP_UNE   : mnemonic = "fcmp_une"; break;
    case llop::FCMP_TRUE  : mnemonic = "true";     break;
    // integer predicates
    case llop::ICMP_EQ    : mnemonic = "icmp_eq";  break;
    case llop::ICMP_NE    : mnemonic = "icmp_ne";  break;
    case llop::ICMP_UGT   : mnemonic = "icmp_ugt"; break;
    case llop::ICMP_UGE   : mnemonic = "icmp_uge"; break;
    case llop::ICMP_ULT   : mnemonic = "icmp_ult"; break;
    case llop::ICMP_ULE   : mnemonic = "icmp_ule"; break;
    case llop::ICMP_SGT   : mnemonic = "icmp_sgt"; break;
    case llop::ICMP_SGE   : mnemonic = "icmp_sge"; break;
    case llop::ICMP_SLT   : mnemonic = "icmp_slt"; break;
    case llop::ICMP_SLE   : mnemonic = "icmp_sle"; break;
    default: break;
  }
  // No case matched: the predicate has no mnemonic in the table above.
  if(mnemonic == nullptr)
    throw std::runtime_error("unreachable");
  return mnemonic;
}
cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, 2, name, next), pred_(pred) {
set_operand(0, lhs);
@@ -123,11 +179,11 @@ type* cmp_inst::make_cmp_result_type(type *ty){
bool cmp_inst::is_fp_predicate(pred_t pred) {
return pred >= pcmp::FIRST_FCMP_PREDICATE && pred <= pcmp::LAST_FCMP_PREDICATE;
return pred >= llop::FIRST_FCMP_PREDICATE && pred <= llop::LAST_FCMP_PREDICATE;
}
bool cmp_inst::is_int_predicate(pred_t pred) {
return pred >= pcmp::FIRST_ICMP_PREDICATE && pred <= pcmp::LAST_ICMP_PREDICATE;
return pred >= llop::FIRST_ICMP_PREDICATE && pred <= llop::LAST_ICMP_PREDICATE;
}
// icmp_inst
@@ -157,6 +213,24 @@ unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction
// cast_inst classes
//===----------------------------------------------------------------------===//
// Textual mnemonic for this cast, selected from op_.
// Throws std::runtime_error when op_ is not one of the cast kinds listed.
std::string cast_inst::repr_impl() const {
  const char *mnemonic = nullptr;
  switch (op_){
    // integer width changes
    case ic::Trunc:         mnemonic = "trunc";           break;
    case ic::ZExt:          mnemonic = "zext";            break;
    case ic::SExt:          mnemonic = "sext";            break;
    // floating-point width changes
    case ic::FPTrunc:       mnemonic = "fp_trunc";        break;
    case ic::FPExt:         mnemonic = "fp_ext";          break;
    // integer <-> floating-point
    case ic::UIToFP:        mnemonic = "ui_to_fp";        break;
    case ic::SIToFP:        mnemonic = "si_to_fp";        break;
    case ic::FPToUI:        mnemonic = "fp_to_ui";        break;
    case ic::FPToSI:        mnemonic = "fp_to_si";        break;
    // pointer conversions and reinterpretation
    case ic::PtrToInt:      mnemonic = "ptr_to_int";      break;
    case ic::IntToPtr:      mnemonic = "int_to_ptr";      break;
    case ic::BitCast:       mnemonic = "bitcast";         break;
    case ic::AddrSpaceCast: mnemonic = "addr_space_cast"; break;
    default: break;
  }
  // No case matched: the cast kind has no mnemonic in the table above.
  if(mnemonic == nullptr)
    throw std::runtime_error("unreachable");
  return mnemonic;
}
// TODO
bool cast_inst::is_valid(op_t op, value *arg, type *ty) {
return true;
@@ -331,6 +405,18 @@ store_inst* store_inst::create(value *ptr, value *v, const std::string &name, in
// retile_inst classes
//===----------------------------------------------------------------------===//
// Formats the tile shape of `ty` as a bracketed, comma-separated list,
// e.g. "[16, 8]"; a type with no tile dimensions yields "[]".
// The result is appended to the mnemonic of the retiling instructions.
std::string retile_inst::shape_suffix(ir::type* ty){
  // Query the shape once and reuse it; the loop previously re-fetched it
  // from the type on every iteration.
  const auto& shapes = ty->get_tile_shapes();
  std::string res = "[";
  for(unsigned i = 0; i < shapes.size(); i++){
    // Separator goes before every element but the first, which avoids the
    // unsigned `size() - 1` comparison.
    if(i > 0)
      res += ", ";
    res += std::to_string(shapes[i]);
  }
  res += "]";
  return res;
}
retile_inst::retile_inst(value *arg, const std::vector<unsigned> &shapes,
const std::string &name, instruction *next)
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }

View File

@@ -29,6 +29,14 @@ void module::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
void module::set_type(const std::string& name, ir::basic_block *block, ir::type *type){
types_[val_key_t{name, block}] = type;
}
void module::set_type(const std::string& name, ir::type *type){
return set_type(name, builder_.get_insert_block(), type);
}
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
@@ -42,14 +50,14 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::vector<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::back_inserter(non_self_ref),
[phi](ir::value* op){ return op != phi; });
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
[phi](ir::value* op){ return op != phi && op; });
// non-trivial
if(non_self_ref.size() > 1)
if(non_self_ref.size() != 1)
return phi;
// unique value or self-reference
ir::value *same = non_self_ref[0];
ir::value *same = *non_self_ref.begin();
std::set<ir::user*> users = phi->get_users();
phi->replace_all_uses_with(same);
phi->erase_from_parent();
@@ -57,9 +65,12 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
if(auto *new_phi = dynamic_cast<ir::phi_node*>(same))
return try_remove_trivial_phis(new_phi);
return same;
}
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
@@ -75,9 +86,9 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
auto &preds = block->get_predecessors();
if(block)
if(sealed_blocks_.find(block) == sealed_blocks_.end()){
ir::value *pred = get_value(name, preds.front());
incomplete_phis_[block][name] = make_phi(pred->get_type(), 1, block);
incomplete_phis_[block][name] = make_phi(get_type(name, block), 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
else if(preds.size() <= 1){
@@ -85,8 +96,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
ir::value *pred = get_value(name, preds.front());
result = make_phi(pred->get_type(), 1, block);
result = make_phi(get_type(name, block), 1, block);
set_value(name, block, result);
result = add_phi_operands(name, (ir::phi_node*&)result);
}
@@ -112,6 +122,21 @@ ir::value *module::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
// Looks up the declared type of variable `name` as seen from `block`.
// If nothing is stored for (name, block), the search continues in the
// block's first predecessor; a block without predecessors falls back to the
// entry stored under a null block.
// NOTE(review): when the name is found nowhere, the lookup ends with a null
// `block` and assert(block) is the only guard -- confirm callers never ask
// for a name that was not registered through set_type.
ir::type *module::get_type(const std::string &name, basic_block *block) {
  const auto found = types_.find(val_key_t(name, block));
  if(found != types_.end())
    return found->second;
  assert(block);
  const auto& preds = block->get_predecessors();
  // Only the first predecessor is followed.
  basic_block *parent = preds.empty() ? nullptr : preds[0];
  return get_type(name, parent);
}
// Convenience overload: declared type of variable `name` as seen from the
// block the builder is currently inserting into.
// Forwards to the block-aware overload so that a type registered in an
// earlier block is still found after the insert point has moved on; an exact
// lookup on the current block alone threw std::out_of_range in that case.
ir::type *module::get_type(const std::string &name) {
  return get_type(name, builder_.get_insert_block());
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block])
add_phi_operands(x.first, x.second);