[code generation] more bugfixes in control flow
This commit is contained in:
@@ -62,8 +62,8 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
pb = pb + 8*K;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k <= 8){\
|
||||
@checka a = *pa;\
|
||||
if(k > 8){\
|
||||
continue;\
|
||||
}\
|
||||
}\
|
||||
@checkc *pc = C;\
|
||||
@@ -218,7 +218,6 @@ int main() {
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
tdl::ir::print(module, std::cout);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
|
@@ -374,6 +374,18 @@ private:
|
||||
const node *statements_;
|
||||
};
|
||||
|
||||
// Jump
|
||||
|
||||
// Common base class for jump statements (continue_statement derives from it).
// Adds no members of its own; it only inherits the constructors of `statement`.
class jump_statement: public statement{
|
||||
public:
|
||||
// Re-export the base-class constructors.
using statement::statement;
|
||||
};
|
||||
|
||||
// AST node for a `continue;` statement.
class continue_statement: public jump_statement{
|
||||
public:
|
||||
// Generates IR for this statement into `mod` and returns the resulting value.
ir::value* codegen(ir::module *mod) const;
|
||||
};
|
||||
|
||||
// Statement type that declares no members beyond what `statement` provides.
class no_op: public statement { };
|
||||
|
||||
// Types
|
||||
|
@@ -48,7 +48,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
|
||||
%token IF ELSE FOR
|
||||
%token IF ELSE FOR CONTINUE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_GLOBAL_RANGE DOT
|
||||
|
||||
@@ -266,6 +266,7 @@ statement
|
||||
| expression_statement { $$ = $1; }
|
||||
| selection_statement { $$ = $1; }
|
||||
| iteration_statement { $$ = $1; }
|
||||
| jump_statement { $$ = $1; }
|
||||
;
|
||||
|
||||
compound_statement
|
||||
@@ -300,6 +301,9 @@ iteration_statement
|
||||
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
||||
;
|
||||
|
||||
/* Jump statements: currently only `continue;`, which builds a continue_statement node. */
jump_statement
|
||||
: CONTINUE ';' { $$ = new continue_statement(); }
|
||||
;
|
||||
|
||||
/* -------------------------- */
|
||||
/* Declarator */
|
||||
|
@@ -37,6 +37,7 @@ int comment();
|
||||
"..." { count(); return(ELLIPSIS); }
|
||||
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
|
||||
"dot" { count(); return DOT;}
|
||||
"continue" { count(); return(CONTINUE); }
|
||||
|
||||
{L}({L}|{D})* { count(); return(check_type()); }
|
||||
|
||||
|
@@ -9,6 +9,7 @@ namespace tdl {
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class phi_node;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -19,8 +20,10 @@ public:
|
||||
// queries
|
||||
bool is_double(ir::value *x);
|
||||
bool is_shared(ir::value *x);
|
||||
bool is_loop_latch(ir::phi_node *phi, ir::value *terminator);
|
||||
ir::value *get_reference(ir::value *x);
|
||||
|
||||
|
||||
private:
|
||||
std::set<ir::value*> shared_;
|
||||
std::set<ir::value*> double_;
|
||||
|
@@ -4,9 +4,17 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "builder.h"
|
||||
|
||||
namespace tdl{
|
||||
|
||||
namespace ast{
|
||||
|
||||
class iteration_statement;
|
||||
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
@@ -27,10 +35,14 @@ class module {
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
// Pairs an AST iteration statement with an IR basic block.
// NOTE(review): no use of this struct is visible here; presumably it describes
// the loop currently being lowered -- confirm against the rest of the file.
struct current_iteration_info_t{
|
||||
// The loop statement (AST node).
ast::iteration_statement *statement;
|
||||
// The basic block associated with that loop statement.
basic_block *block;
|
||||
};
|
||||
|
||||
private:
|
||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi, value **pre_user);
|
||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
@@ -44,11 +56,13 @@ public:
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_type(const std::string& name, basic_block* block, type* x);
|
||||
void set_type(const std::string& name, type* x);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
type *get_type(const std::string& name, basic_block* block);
|
||||
type *get_type(const std::string& name);
|
||||
std::function<ir::value*()> get_continue_fn();
|
||||
// Seal block -- no more predecessors will be added
|
||||
void seal_block(basic_block *block);
|
||||
// Functions
|
||||
@@ -67,6 +81,8 @@ private:
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -141,6 +141,11 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
|
||||
rhs = builder.create_broadcast(rhs, shapes);
|
||||
}
|
||||
|
||||
/* Helper */
|
||||
// Returns true when `x` is non-null and its dynamic type is ir::terminator_inst
// (or a class derived from it); returns false otherwise, including for nullptr.
inline bool is_terminator(ir::value* x) {
|
||||
return x && dynamic_cast<ir::terminator_inst*>(x);
|
||||
}
|
||||
|
||||
/* Translation unit */
|
||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||
decls_->codegen(mod);
|
||||
@@ -242,8 +247,13 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
if(decls_)
|
||||
decls_->codegen(mod);
|
||||
if(statements_)
|
||||
statements_->codegen(mod);
|
||||
if(statements_){
|
||||
for(statement *stmt: statements_->values()){
|
||||
ir::value *current = stmt->codegen(mod);
|
||||
if(is_terminator(current))
|
||||
return current;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -266,15 +276,18 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
|
||||
ir::basic_block *current_bb = builder.get_insert_block();
|
||||
ir::function *fn = current_bb->get_parent();
|
||||
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||
mod->set_continue_fn([&](){
|
||||
exec_->codegen(mod);
|
||||
ir::value *cond = stop_->codegen(mod);
|
||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||
});
|
||||
init_->codegen(mod);
|
||||
builder.create_br(loop_bb);
|
||||
builder.set_insert_point(loop_bb);
|
||||
statements_->codegen(mod);
|
||||
exec_->codegen(mod);
|
||||
ir::value *cond = stop_->codegen(mod);
|
||||
if(!is_terminator(statements_->codegen(mod)))
|
||||
mod->get_continue_fn()();
|
||||
ir::basic_block *stop_bb = builder.get_insert_block();
|
||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||
builder.create_cond_br(cond, loop_bb, next_bb);
|
||||
mod->seal_block(stop_bb);
|
||||
mod->seal_block(loop_bb);
|
||||
mod->seal_block(builder.get_insert_block());
|
||||
@@ -303,16 +316,22 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
|
||||
builder.create_cond_br(cond, then_bb, endif_bb);
|
||||
// Then
|
||||
builder.set_insert_point(then_bb);
|
||||
then_value_->codegen(mod);
|
||||
builder.create_br(endif_bb);
|
||||
if(!is_terminator(then_value_->codegen(mod)))
|
||||
builder.create_br(endif_bb);
|
||||
// Else
|
||||
if(else_value_){
|
||||
builder.set_insert_point(else_bb);
|
||||
else_value_->codegen(mod);
|
||||
builder.create_br(endif_bb);
|
||||
if(!is_terminator(else_value_->codegen(mod)))
|
||||
builder.create_br(endif_bb);
|
||||
}
|
||||
// Endif
|
||||
builder.set_insert_point(endif_bb);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/* Continue statement */
|
||||
// Lowers `continue;` by fetching the module's continue callback and invoking it;
// whatever value the callback produces is returned to the caller.
// NOTE(review): get_continue_fn() returns a std::function; if it is empty
// (e.g. `continue` used where no callback was installed), invoking it throws
// std::bad_function_call -- confirm whether a clearer diagnostic is wanted.
ir::value* continue_statement::codegen(ir::module *mod) const{
|
||||
return mod->get_continue_fn()();
|
||||
}
|
||||
|
||||
/* Declaration */
|
||||
|
@@ -106,9 +106,7 @@ void allocation::run(){
|
||||
ir::phi_node *phi = (ir::phi_node*)x;
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
|
||||
ir::value *inc_val = phi->get_incoming_value(i);
|
||||
assert(offsets_.find(inc_val) == offsets_.end());
|
||||
offsets_[inc_val] = offsets_[phi];
|
||||
std::cout << x->get_name() << " " << inc_val->get_name() << " " << inc_val << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -11,6 +11,16 @@ namespace codegen{
|
||||
|
||||
|
||||
// run pass on module
|
||||
// Reports whether `terminator` branches to the basic block that contains `phi`.
//  - conditional branch: true if either the true or the false destination is
//    phi's parent block;
//  - unconditional branch: true if its destination is phi's parent block;
//  - anything else (including a null pointer, for which both casts fail):
//    throws std::runtime_error("unreachable").
// NOTE(review): the visible callers pass the last instruction of one of phi's
// incoming blocks; if incoming blocks are always CFG predecessors, that
// terminator necessarily targets phi's block and this would return true for
// every incoming edge -- confirm this is the intended latch test.
bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
|
||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||
return br->get_true_dest() == phi->get_parent()
|
||||
|| br->get_false_dest() == phi->get_parent();
|
||||
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
|
||||
return br->get_dest() == phi->get_parent();
|
||||
else
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void buffer_info_pass::run(ir::module &mod) {
|
||||
// Find which buffers are shared
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
@@ -35,13 +45,7 @@ void buffer_info_pass::run(ir::module &mod) {
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
ir::value *terminator = inc_block->get_inst_list().back();
|
||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||
is_double = is_double || br->get_true_dest() == phi->get_parent()
|
||||
|| br->get_false_dest() == phi->get_parent();
|
||||
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
|
||||
is_double = is_double || br->get_dest() == phi->get_parent();
|
||||
else
|
||||
throw std::runtime_error("unreachable");
|
||||
is_double = is_double || is_loop_latch(phi, terminator);
|
||||
}
|
||||
// add to double-buffered
|
||||
if(is_double)
|
||||
@@ -49,7 +53,6 @@ void buffer_info_pass::run(ir::module &mod) {
|
||||
// set references of input
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
assert(refs_[inc_val] == nullptr);
|
||||
refs_[inc_val] = phi;
|
||||
}
|
||||
}
|
||||
|
@@ -299,7 +299,6 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
|
||||
}
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
std::cout << "name: " << v->get_name() << std::endl;
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
@@ -406,8 +405,6 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
unsigned id_pre = 0, id_loop = 1;
|
||||
if(phi->get_incoming_block(0) == phi->get_parent())
|
||||
std::swap(id_pre, id_loop);
|
||||
ir::value *pre_value = phi->get_incoming_value(id_pre);
|
||||
ir::value *loop_value = phi->get_incoming_value(id_loop);
|
||||
if(parent->empty())
|
||||
builder.SetInsertPoint(parent);
|
||||
else
|
||||
@@ -419,8 +416,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset);
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
|
||||
tmap_.insert({pre_value, new shared_tile(ty, shapes, pre_ptr, builder)});
|
||||
tmap_.insert({loop_value, new shared_tile(ty, shapes, next_ptr, builder)});
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(i);
|
||||
ir::value* inc_value = phi->get_incoming_value(i);
|
||||
ir::value* terminator = inc_block->get_inst_list().back();
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
|
||||
}
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("unknown shared memory tile");
|
||||
@@ -479,7 +481,6 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
|
||||
|
||||
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
|
||||
std::cout << "lowering " << ins->get_name() << std::endl;
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Module *module = block->getModule();
|
||||
Function *function = block->getParent();
|
||||
@@ -696,7 +697,6 @@ void selection::run(ir::module &src, Module &dst){
|
||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||
// iterate through block
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
std::cout << "block: " << block->get_name() << std::endl;
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
@@ -734,12 +734,10 @@ void selection::run(ir::module &src, Module &dst){
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << "phi: " << phi->get_name() << std::endl;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
std::cout << "incoming block: " << inc_block->get_name() << " " << llvm_inc_block->getName().str() << std::endl;
|
||||
if(phi->get_type()->is_tile_ty()) {
|
||||
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
|
||||
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
|
||||
|
@@ -13,7 +13,6 @@ namespace codegen{
|
||||
void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) {
|
||||
if(auto *i = dynamic_cast<ir::instruction*>(x)){
|
||||
ir::basic_block* block = i->get_parent();
|
||||
std::cout << "adding copy: " << x->get_name() << " " << block->get_name() << std::endl;
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
builder.set_insert_point(++it);
|
||||
}
|
||||
|
@@ -104,7 +104,6 @@ std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && *x.second == 0){
|
||||
std::cout << typeid(*i).name() << std::endl;
|
||||
result.push_back(x.second);
|
||||
}
|
||||
return result;
|
||||
|
@@ -37,6 +37,14 @@ void module::set_type(const std::string& name, ir::type *type){
|
||||
return set_type(name, builder_.get_insert_block(), type);
|
||||
}
|
||||
|
||||
// Stores `fn` as the module's continue callback, replacing any previous one.
// NOTE(review): the previous callback is overwritten, not saved; the caller
// visible in this diff installs a lambda that captures its locals by reference,
// so the stored callback presumably must not be invoked after that caller
// returns (e.g. a `continue` following a nested loop) -- verify.
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
|
||||
// Returns a copy of the currently stored continue callback.
// The result is an empty std::function if none has been stored.
std::function<ir::value*()> module::get_continue_fn() {
|
||||
return continue_fn_;
|
||||
}
|
||||
|
||||
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||
basic_block::iterator insert = block->get_first_non_phi();
|
||||
if(insert != block->end()){
|
||||
@@ -61,8 +69,6 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_u
|
||||
std::set<ir::user*> users = phi->get_users();
|
||||
phi->replace_all_uses_with(same);
|
||||
phi->erase_from_parent();
|
||||
if(pre_user)
|
||||
*pre_user = same;
|
||||
for(ir::user* u: users)
|
||||
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
|
||||
if(uphi != phi)
|
||||
@@ -80,11 +86,10 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
|
||||
ir::value *value = get_value(name, pred);
|
||||
phi->add_incoming(value, pred);
|
||||
}
|
||||
return try_remove_trivial_phis(phi, nullptr);
|
||||
return phi;
|
||||
}
|
||||
|
||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
std::cout << "getting value " << name << std::endl;
|
||||
ir::value *result;
|
||||
auto &preds = block->get_predecessors();
|
||||
if(block)
|
||||
@@ -101,6 +106,8 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
set_value(name, block, result);
|
||||
result = add_phi_operands(name, (ir::phi_node*&)result);
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
|
||||
result = try_remove_trivial_phis(phi, nullptr);
|
||||
set_value(name, block, result);
|
||||
return result;
|
||||
}
|
||||
@@ -138,9 +145,12 @@ ir::type *module::get_type(const std::string &name) {
|
||||
return types_.at({name, builder_.get_insert_block()});
|
||||
}
|
||||
|
||||
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block])
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
add_phi_operands(x.first, x.second);
|
||||
try_remove_trivial_phis(x.second, nullptr);
|
||||
}
|
||||
sealed_blocks_.insert(block);
|
||||
incomplete_phis_[block].clear();
|
||||
}
|
||||
|
Reference in New Issue
Block a user