diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 5c7aee602..df239e6ff 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -62,8 +62,8 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\ pb = pb + 8*K;\ @checka a = *pa;\ @checkb b = *pb;\ - if(k <= 8){\ - @checka a = *pa;\ + if(k > 8){\ + continue;\ }\ }\ @checkc *pc = C;\ @@ -218,7 +218,6 @@ int main() { buffer_info.run(module); shared.run(module); liveness.run(module); - tdl::ir::print(module, std::cout); allocation.run(); barriers.run(module); vectorize.run(module); diff --git a/include/ast/ast.h b/include/ast/ast.h index 4a9889093..529c4b01b 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -374,6 +374,18 @@ private: const node *statements_; }; +// Jump + +class jump_statement: public statement{ +public: + using statement::statement; +}; + +class continue_statement: public jump_statement{ +public: + ir::value* codegen(ir::module *mod) const; +}; + class no_op: public statement { }; // Types diff --git a/include/ast/parser.y b/include/ast/parser.y index 905541d70..43c530e12 100644 --- a/include/ast/parser.y +++ b/include/ast/parser.y @@ -48,7 +48,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token XOR_ASSIGN OR_ASSIGN TYPE_NAME %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 -%token IF ELSE FOR +%token IF ELSE FOR CONTINUE %token NEWAXIS ELLIPSIS AT %token GET_GLOBAL_RANGE DOT @@ -266,6 +266,7 @@ statement | expression_statement { $$ = $1; } | selection_statement { $$ = $1; } | iteration_statement { $$ = $1; } + | jump_statement { $$ = $1; } ; compound_statement @@ -300,6 +301,9 @@ iteration_statement : FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } ; +jump_statement + : CONTINUE ';' { $$ = new continue_statement(); } +; /* -------------------------- */ /* Declarator */ diff --git a/include/ast/scanner.l b/include/ast/scanner.l index 8e2d89f14..80da95dad 100644 --- a/include/ast/scanner.l +++ b/include/ast/scanner.l @@ -37,6 +37,7 @@ int comment(); "..." { count(); return(ELLIPSIS); } "get_global_range" { count(); return GET_GLOBAL_RANGE; } "dot" { count(); return DOT;} +"continue" { count(); return(CONTINUE); } {L}({L}|{D})* { count(); return(check_type()); } diff --git a/include/codegen/buffer_info.h b/include/codegen/buffer_info.h index 2cce9d829..0d22608c2 100644 --- a/include/codegen/buffer_info.h +++ b/include/codegen/buffer_info.h @@ -9,6 +9,7 @@ namespace tdl { namespace ir { class module; class value; + class phi_node; } namespace codegen{ @@ -19,8 +20,10 @@ public: // queries bool is_double(ir::value *x); bool is_shared(ir::value *x); + bool is_loop_latch(ir::phi_node *phi, ir::value *terminator); ir::value *get_reference(ir::value *x); + private: std::set shared_; std::set double_; diff --git a/include/ir/module.h b/include/ir/module.h index a5769c05f..347178fda 100644 --- a/include/ir/module.h +++ b/include/ir/module.h @@ -4,9 +4,17 @@ #include #include #include +#include #include "builder.h" namespace tdl{ + +namespace ast{ + +class iteration_statement; + +} + namespace ir{ class basic_block; @@ -27,10 +35,14 @@ class module { public: typedef std::map symbols_map_t; typedef std::vector functions_list_t; + struct current_iteration_info_t{ + ast::iteration_statement *statement; + basic_block *block; + }; private: phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); - value *try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user); + value *try_remove_trivial_phis(ir::phi_node *&phi, value **pre_user); value *add_phi_operands(const std::string& name, phi_node *&phi); value *get_value_recursive(const std::string& name, basic_block *block); void push_function(function *fn) { functions_.push_back(fn); } @@ -44,11 +56,13 @@ public: void set_value(const std::string& name, value* x); void set_type(const std::string& name, basic_block* block, type* x); void set_type(const std::string& name, type* x); + void set_continue_fn(std::function fn); // Getters value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); type *get_type(const std::string& name, basic_block* block); type *get_type(const std::string& name); + std::function get_continue_fn(); // Seal block -- no more predecessors will be added void seal_block(basic_block *block); // Functions @@ -67,6 +81,8 @@ private: std::map> incomplete_phis_; functions_list_t functions_; symbols_map_t symbols_; + std::function continue_fn_; + std::map current_phi_; }; } diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 231c83930..f11c40955 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -141,6 +141,11 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) rhs = builder.create_broadcast(rhs, shapes); } +/* Helper */ +inline bool is_terminator(ir::value* x) { + return x && dynamic_cast(x); +} + /* Translation unit */ ir::value* translation_unit::codegen(ir::module *mod) const{ decls_->codegen(mod); @@ -242,8 +247,13 @@ ir::value* function_definition::codegen(ir::module *mod) const{ ir::value* compound_statement::codegen(ir::module* mod) const{ if(decls_) decls_->codegen(mod); - if(statements_) - statements_->codegen(mod); + if(statements_){ + for(statement *stmt: statements_->values()){ + ir::value *current = stmt->codegen(mod); + if(is_terminator(current)) + return current; + } + } return nullptr; } @@ -266,15 +276,18 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{ ir::basic_block *current_bb = builder.get_insert_block(); ir::function *fn = current_bb->get_parent(); ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); + ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); + mod->set_continue_fn([&](){ + exec_->codegen(mod); + ir::value *cond = stop_->codegen(mod); + return builder.create_cond_br(cond, loop_bb, next_bb); + }); init_->codegen(mod); builder.create_br(loop_bb); builder.set_insert_point(loop_bb); - statements_->codegen(mod); - exec_->codegen(mod); - ir::value *cond = stop_->codegen(mod); + if(!is_terminator(statements_->codegen(mod))) + mod->get_continue_fn()(); ir::basic_block *stop_bb = builder.get_insert_block(); - ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); - builder.create_cond_br(cond, loop_bb, next_bb); mod->seal_block(stop_bb); mod->seal_block(loop_bb); mod->seal_block(builder.get_insert_block()); @@ -303,16 +316,22 @@ ir::value* selection_statement::codegen(ir::module* mod) const{ builder.create_cond_br(cond, then_bb, endif_bb); // Then builder.set_insert_point(then_bb); - then_value_->codegen(mod); - builder.create_br(endif_bb); + if(!is_terminator(then_value_->codegen(mod))) + builder.create_br(endif_bb); // Else if(else_value_){ builder.set_insert_point(else_bb); - else_value_->codegen(mod); - builder.create_br(endif_bb); + if(!is_terminator(else_value_->codegen(mod))) + builder.create_br(endif_bb); } // Endif builder.set_insert_point(endif_bb); + return nullptr; +} + +/* Continue statement */ +ir::value* continue_statement::codegen(ir::module *mod) const{ + return mod->get_continue_fn()(); } /* Declaration */ diff --git a/lib/codegen/allocation.cpp b/lib/codegen/allocation.cpp index 34ba1e59a..696b46cb9 100644 --- a/lib/codegen/allocation.cpp +++ b/lib/codegen/allocation.cpp @@ -106,9 +106,7 @@ void allocation::run(){ ir::phi_node *phi = (ir::phi_node*)x; for(unsigned i = 0; i < phi->get_num_incoming(); i++){ ir::value *inc_val = phi->get_incoming_value(i); - assert(offsets_.find(inc_val) == offsets_.end()); offsets_[inc_val] = offsets_[phi]; - std::cout << x->get_name() << " " << inc_val->get_name() << " " << inc_val << std::endl; } } } diff --git a/lib/codegen/buffer_info.cpp b/lib/codegen/buffer_info.cpp index 435f8ea8e..4d20fa6e9 100644 --- a/lib/codegen/buffer_info.cpp +++ b/lib/codegen/buffer_info.cpp @@ -11,6 +11,16 @@ namespace codegen{ // run pass on module +bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){ + if(auto *br = dynamic_cast(terminator)) + return br->get_true_dest() == phi->get_parent() + || br->get_false_dest() == phi->get_parent(); + else if(auto *br = dynamic_cast(terminator)) + return br->get_dest() == phi->get_parent(); + else + throw std::runtime_error("unreachable"); +} + void buffer_info_pass::run(ir::module &mod) { // Find which buffers are shared for(ir::function *fn: mod.get_function_list()) @@ -35,13 +45,7 @@ void buffer_info_pass::run(ir::module &mod) { for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::basic_block *inc_block = phi->get_incoming_block(n); ir::value *terminator = inc_block->get_inst_list().back(); - if(auto *br = dynamic_cast(terminator)) - is_double = is_double || br->get_true_dest() == phi->get_parent() - || br->get_false_dest() == phi->get_parent(); - else if(auto *br = dynamic_cast(terminator)) - is_double = is_double || br->get_dest() == phi->get_parent(); - else - throw std::runtime_error("unreachable"); + is_double = is_double || is_loop_latch(phi, terminator); } // add to double-buffered if(is_double) @@ -49,7 +53,6 @@ void buffer_info_pass::run(ir::module &mod) { // set references of input for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::value *inc_val = phi->get_incoming_value(n); - assert(refs_[inc_val] == nullptr); refs_[inc_val] = phi; } } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index b7e3461d8..11bcf2738 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -299,7 +299,6 @@ std::vector delinearize(Value *trailing, std::vector &shapes, } void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - std::cout << "name: " << v->get_name() << std::endl; const auto& shapes = v->get_type()->get_tile_shapes(); size_t dim = shapes.size(); std::vector contiguous(dim); @@ -406,8 +405,6 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, unsigned id_pre = 0, id_loop = 1; if(phi->get_incoming_block(0) == phi->get_parent()) std::swap(id_pre, id_loop); - ir::value *pre_value = phi->get_incoming_value(id_pre); - ir::value *loop_value = phi->get_incoming_value(id_loop); if(parent->empty()) builder.SetInsertPoint(parent); else @@ -419,8 +416,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); Value *next_ptr = builder.CreateGEP(ptr, offset); tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); - tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)}); - tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)}); + for(unsigned i = 0; i < phi->get_num_incoming(); i++) { + ir::basic_block* inc_block = phi->get_incoming_block(i); + ir::value* inc_value = phi->get_incoming_value(i); + ir::value* terminator = inc_block->get_inst_list().back(); + bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); + tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)}); + } } else throw std::runtime_error("unknown shared memory tile"); @@ -479,7 +481,6 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) { - std::cout << "lowering " << ins->get_name() << std::endl; BasicBlock *block = builder.GetInsertBlock(); Module *module = block->getModule(); Function *function = block->getParent(); @@ -696,7 +697,6 @@ void selection::run(ir::module &src, Module &dst){ std::map last_block; // iterate through block for(ir::basic_block *block: fn->blocks()) { - std::cout << "block: " << block->get_name() << std::endl; BasicBlock *parent = (BasicBlock*)vmap_[block]; dst_builder.SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ @@ -734,12 +734,10 @@ void selection::run(ir::module &src, Module &dst){ } } else { - std::cout << "phi: " << phi->get_name() << std::endl; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::value *inc_val = phi->get_incoming_value(n); ir::basic_block *inc_block = phi->get_incoming_block(n); BasicBlock *llvm_inc_block = last_block.at(inc_block); - std::cout << "incoming block: " << inc_block->get_name() << " " << llvm_inc_block->getName().str() << std::endl; if(phi->get_type()->is_tile_ty()) { distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi); distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val); diff --git a/lib/codegen/shared_copy.cpp b/lib/codegen/shared_copy.cpp index f759003bd..60c31199f 100644 --- a/lib/codegen/shared_copy.cpp +++ b/lib/codegen/shared_copy.cpp @@ -13,7 +13,6 @@ namespace codegen{ void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) { if(auto *i = dynamic_cast(x)){ ir::basic_block* block = i->get_parent(); - std::cout << "adding copy: " << x->get_name() << " " << block->get_name() << std::endl; auto it = std::find(block->begin(), block->end(), i); builder.set_insert_point(++it); } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 8919f171b..9d4a08f2e 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -104,7 +104,6 @@ std::vector tune::get_params(ir::module &mod) { for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) if(seen.insert(x.second).second && *x.second == 0){ - std::cout << typeid(*i).name() << std::endl; result.push_back(x.second); } return result; diff --git a/lib/ir/module.cpp b/lib/ir/module.cpp index fe6f0d48f..a8a11ff1c 100644 --- a/lib/ir/module.cpp +++ b/lib/ir/module.cpp @@ -37,6 +37,14 @@ void module::set_type(const std::string& name, ir::type *type){ return set_type(name, builder_.get_insert_block(), type); } +void module::set_continue_fn(std::function fn) { + continue_fn_ = fn; +} + +std::function module::get_continue_fn() { + return continue_fn_; +} + ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ basic_block::iterator insert = block->get_first_non_phi(); if(insert != block->end()){ @@ -61,8 +69,6 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_u std::set users = phi->get_users(); phi->replace_all_uses_with(same); phi->erase_from_parent(); - if(pre_user) - *pre_user = same; for(ir::user* u: users) if(auto *uphi = dynamic_cast(u)) if(uphi != phi) @@ -80,11 +86,10 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi) ir::value *value = get_value(name, pred); phi->add_incoming(value, pred); } - return try_remove_trivial_phis(phi, nullptr); + return phi; } ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { - std::cout << "getting value " << name << std::endl; ir::value *result; auto &preds = block->get_predecessors(); if(block) @@ -101,6 +106,8 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block set_value(name, block, result); result = add_phi_operands(name, (ir::phi_node*&)result); } + if(auto *phi = dynamic_cast(result)) + result = try_remove_trivial_phis(phi, nullptr); set_value(name, block, result); return result; } @@ -138,9 +145,12 @@ ir::type *module::get_type(const std::string &name) { return types_.at({name, builder_.get_insert_block()}); } + void module::seal_block(ir::basic_block *block){ - for(auto &x: incomplete_phis_[block]) + for(auto &x: incomplete_phis_[block]){ add_phi_operands(x.first, x.second); + try_remove_trivial_phis(x.second, nullptr); + } sealed_blocks_.insert(block); incomplete_phis_[block].clear(); }