[code generation] more bugfixes in control flow

This commit is contained in:
Philippe Tillet
2019-02-20 22:55:20 -05:00
parent 90ec0ae2c0
commit 5618a15dc1
13 changed files with 103 additions and 42 deletions

View File

@@ -62,8 +62,8 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
pb = pb + 8*K;\
@checka a = *pa;\
@checkb b = *pb;\
if(k <= 8){\
@checka a = *pa;\
if(k > 8){\
continue;\
}\
}\
@checkc *pc = C;\
@@ -218,7 +218,6 @@ int main() {
buffer_info.run(module);
shared.run(module);
liveness.run(module);
tdl::ir::print(module, std::cout);
allocation.run();
barriers.run(module);
vectorize.run(module);

View File

@@ -374,6 +374,18 @@ private:
const node *statements_;
};
// Jump
// Common base for statements that transfer control; in the visible code
// continue_statement is the only class deriving from it.
class jump_statement: public statement{
public:
// Re-use the constructors declared by `statement`.
using statement::statement;
};
// AST node created by the parser for a `continue;` statement.
class continue_statement: public jump_statement{
public:
// Generates the IR for this statement into `mod` and returns the
// resulting value.
ir::value* codegen(ir::module *mod) const;
};
class no_op: public statement { };
// Types

View File

@@ -48,7 +48,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR
%token IF ELSE FOR CONTINUE
%token NEWAXIS ELLIPSIS AT
%token GET_GLOBAL_RANGE DOT
@@ -266,6 +266,7 @@ statement
| expression_statement { $$ = $1; }
| selection_statement { $$ = $1; }
| iteration_statement { $$ = $1; }
| jump_statement { $$ = $1; }
;
compound_statement
@@ -300,6 +301,9 @@ iteration_statement
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
;
/* Jump statements: `continue;` is the only form accepted here. */
jump_statement
: CONTINUE ';' { $$ = new continue_statement(); }
;
/* -------------------------- */
/* Declarator */

View File

@@ -37,6 +37,7 @@ int comment();
"..." { count(); return(ELLIPSIS); }
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
"dot" { count(); return DOT;}
"continue" { count(); return(CONTINUE); }
{L}({L}|{D})* { count(); return(check_type()); }

View File

@@ -9,6 +9,7 @@ namespace tdl {
namespace ir {
class module;
class value;
class phi_node;
}
namespace codegen{
@@ -19,8 +20,10 @@ public:
// queries
bool is_double(ir::value *x);
bool is_shared(ir::value *x);
bool is_loop_latch(ir::phi_node *phi, ir::value *terminator);
ir::value *get_reference(ir::value *x);
private:
std::set<ir::value*> shared_;
std::set<ir::value*> double_;

View File

@@ -4,9 +4,17 @@
#include <map>
#include <set>
#include <string>
#include <functional>
#include "builder.h"
namespace tdl{
namespace ast{
class iteration_statement;
}
namespace ir{
class basic_block;
@@ -27,10 +35,14 @@ class module {
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
// Pairs an AST iteration statement with an IR basic block.
// NOTE(review): nothing in the visible code reads or writes this struct;
// confirm whether it is still needed and what `block` is meant to designate.
struct current_iteration_info_t{
ast::iteration_statement *statement;  // AST node of the loop
basic_block *block;                   // IR block associated with that loop
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user);
value *try_remove_trivial_phis(ir::phi_node *&phi, value **pre_user);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
@@ -44,11 +56,13 @@ public:
void set_value(const std::string& name, value* x);
void set_type(const std::string& name, basic_block* block, type* x);
void set_type(const std::string& name, type* x);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
type *get_type(const std::string& name, basic_block* block);
type *get_type(const std::string& name);
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions
@@ -67,6 +81,8 @@ private:
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
};
}

View File

@@ -141,6 +141,11 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
rhs = builder.create_broadcast(rhs, shapes);
}
/* Helper */
// Tells whether `x` is an ir::terminator_inst.
// A null `x` yields false: dynamic_cast on a null pointer produces null,
// so no separate null check is needed.
inline bool is_terminator(ir::value* x) {
  return dynamic_cast<ir::terminator_inst*>(x) != nullptr;
}
/* Translation unit */
ir::value* translation_unit::codegen(ir::module *mod) const{
decls_->codegen(mod);
@@ -242,8 +247,13 @@ ir::value* function_definition::codegen(ir::module *mod) const{
ir::value* compound_statement::codegen(ir::module* mod) const{
if(decls_)
decls_->codegen(mod);
if(statements_)
statements_->codegen(mod);
if(statements_){
for(statement *stmt: statements_->values()){
ir::value *current = stmt->codegen(mod);
if(is_terminator(current))
return current;
}
}
return nullptr;
}
@@ -266,15 +276,18 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
ir::basic_block *current_bb = builder.get_insert_block();
ir::function *fn = current_bb->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
mod->set_continue_fn([&](){
exec_->codegen(mod);
ir::value *cond = stop_->codegen(mod);
return builder.create_cond_br(cond, loop_bb, next_bb);
});
init_->codegen(mod);
builder.create_br(loop_bb);
builder.set_insert_point(loop_bb);
statements_->codegen(mod);
exec_->codegen(mod);
ir::value *cond = stop_->codegen(mod);
if(!is_terminator(statements_->codegen(mod)))
mod->get_continue_fn()();
ir::basic_block *stop_bb = builder.get_insert_block();
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
builder.create_cond_br(cond, loop_bb, next_bb);
mod->seal_block(stop_bb);
mod->seal_block(loop_bb);
mod->seal_block(builder.get_insert_block());
@@ -303,16 +316,22 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
builder.create_cond_br(cond, then_bb, endif_bb);
// Then
builder.set_insert_point(then_bb);
then_value_->codegen(mod);
builder.create_br(endif_bb);
if(!is_terminator(then_value_->codegen(mod)))
builder.create_br(endif_bb);
// Else
if(else_value_){
builder.set_insert_point(else_bb);
else_value_->codegen(mod);
builder.create_br(endif_bb);
if(!is_terminator(else_value_->codegen(mod)))
builder.create_br(endif_bb);
}
// Endif
builder.set_insert_point(endif_bb);
return nullptr;
}
/* Continue statement */
// Emits the IR for `continue;` by running the callback most recently
// installed on the module through set_continue_fn, and returns its result.
// NOTE(review): if no callback has been installed (e.g. `continue` used
// outside of any loop) the std::function is empty and invoking it throws
// std::bad_function_call.
ir::value* continue_statement::codegen(ir::module *mod) const{
  auto emit_continue = mod->get_continue_fn();
  return emit_continue();
}
/* Declaration */

View File

@@ -106,9 +106,7 @@ void allocation::run(){
ir::phi_node *phi = (ir::phi_node*)x;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i);
assert(offsets_.find(inc_val) == offsets_.end());
offsets_[inc_val] = offsets_[phi];
std::cout << x->get_name() << " " << inc_val->get_name() << " " << inc_val << std::endl;
}
}
}

View File

@@ -11,6 +11,16 @@ namespace codegen{
// run pass on module
// Returns true when `terminator` can branch to the basic block that
// contains `phi`.
//   - conditional branch: true if either destination is phi's parent block
//   - unconditional branch: true if its destination is phi's parent block
// Throws std::runtime_error("unreachable") when `terminator` is neither kind
// of branch instruction.
bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
  auto *cond_br = dynamic_cast<ir::cond_branch_inst*>(terminator);
  if(cond_br){
    auto *phi_block = phi->get_parent();
    if(cond_br->get_true_dest() == phi_block)
      return true;
    return cond_br->get_false_dest() == phi_block;
  }
  auto *uncond_br = dynamic_cast<ir::uncond_branch_inst*>(terminator);
  if(uncond_br)
    return uncond_br->get_dest() == phi->get_parent();
  throw std::runtime_error("unreachable");
}
void buffer_info_pass::run(ir::module &mod) {
// Find which buffers are shared
for(ir::function *fn: mod.get_function_list())
@@ -35,13 +45,7 @@ void buffer_info_pass::run(ir::module &mod) {
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::value *terminator = inc_block->get_inst_list().back();
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
is_double = is_double || br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
is_double = is_double || br->get_dest() == phi->get_parent();
else
throw std::runtime_error("unreachable");
is_double = is_double || is_loop_latch(phi, terminator);
}
// add to double-buffered
if(is_double)
@@ -49,7 +53,6 @@ void buffer_info_pass::run(ir::module &mod) {
// set references of input
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
assert(refs_[inc_val] == nullptr);
refs_[inc_val] = phi;
}
}

View File

@@ -299,7 +299,6 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
std::cout << "name: " << v->get_name() << std::endl;
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);
@@ -406,8 +405,6 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
unsigned id_pre = 0, id_loop = 1;
if(phi->get_incoming_block(0) == phi->get_parent())
std::swap(id_pre, id_loop);
ir::value *pre_value = phi->get_incoming_value(id_pre);
ir::value *loop_value = phi->get_incoming_value(id_loop);
if(parent->empty())
builder.SetInsertPoint(parent);
else
@@ -419,8 +416,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset);
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)});
tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)});
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
ir::basic_block* inc_block = phi->get_incoming_block(i);
ir::value* inc_value = phi->get_incoming_value(i);
ir::value* terminator = inc_block->get_inst_list().back();
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
}
}
else
throw std::runtime_error("unknown shared memory tile");
@@ -479,7 +481,6 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
std::cout << "lowering " << ins->get_name() << std::endl;
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
Function *function = block->getParent();
@@ -696,7 +697,6 @@ void selection::run(ir::module &src, Module &dst){
std::map<ir::basic_block*, BasicBlock*> last_block;
// iterate through block
for(ir::basic_block *block: fn->blocks()) {
std::cout << "block: " << block->get_name() << std::endl;
BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
@@ -734,12 +734,10 @@ void selection::run(ir::module &src, Module &dst){
}
}
else {
std::cout << "phi: " << phi->get_name() << std::endl;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = last_block.at(inc_block);
std::cout << "incoming block: " << inc_block->get_name() << " " << llvm_inc_block->getName().str() << std::endl;
if(phi->get_type()->is_tile_ty()) {
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);

View File

@@ -13,7 +13,6 @@ namespace codegen{
void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) {
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
std::cout << "adding copy: " << x->get_name() << " " << block->get_name() << std::endl;
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}

View File

@@ -104,7 +104,6 @@ std::vector<unsigned*> tune::get_params(ir::module &mod) {
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && *x.second == 0){
std::cout << typeid(*i).name() << std::endl;
result.push_back(x.second);
}
return result;

View File

@@ -37,6 +37,14 @@ void module::set_type(const std::string& name, ir::type *type){
return set_type(name, builder_.get_insert_block(), type);
}
// Installs `fn` as the callback that emits the IR for a `continue`
// statement; it replaces any callback installed earlier.
// `fn` is already a private copy (taken by value), so it is swapped into the
// member instead of being copied a second time. The previous callback ends
// up in `fn` and is destroyed when this function returns.
void module::set_continue_fn(std::function<ir::value*()> fn) {
  continue_fn_.swap(fn);
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
@@ -61,8 +69,6 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_u
std::set<ir::user*> users = phi->get_users();
phi->replace_all_uses_with(same);
phi->erase_from_parent();
if(pre_user)
*pre_user = same;
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
@@ -80,11 +86,10 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return try_remove_trivial_phis(phi, nullptr);
return phi;
}
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
std::cout << "getting value " << name << std::endl;
ir::value *result;
auto &preds = block->get_predecessors();
if(block)
@@ -101,6 +106,8 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
set_value(name, block, result);
result = add_phi_operands(name, (ir::phi_node*&)result);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi, nullptr);
set_value(name, block, result);
return result;
}
@@ -138,9 +145,12 @@ ir::type *module::get_type(const std::string &name) {
return types_.at({name, builder_.get_insert_block()});
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block])
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
try_remove_trivial_phis(x.second, nullptr);
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
}