From 5fdb27d9aec401f8a55fac6a3e7a2f6e3da3f4b0 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 6 Feb 2019 23:34:45 -0500 Subject: [PATCH] [code generation] fixed bug in tile phi nodes --- examples/matrix.cpp | 3 +++ include/codegen/selection.h | 2 -- lib/codegen/selection.cpp | 39 ++++++++++++++++++++++--------------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 9599a8456..e5ea0b1ba 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -33,6 +33,9 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\ fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\ fp32* pc[32, 32] = c + rx[:, newaxis] + ry[newaxis, :]*M;\ + for(k = K; k >= 0; k = k - 8){\ + C = C + 1;\ + }\ *pc = C;\ }\ "; diff --git a/include/codegen/selection.h b/include/codegen/selection.h index 9a9e03135..8b39ee150 100644 --- a/include/codegen/selection.h +++ b/include/codegen/selection.h @@ -72,7 +72,6 @@ private: class selection{ typedef std::map vmap_t; - typedef std::map bmap_t; typedef std::map tmap_t; private: @@ -100,7 +99,6 @@ public: private: vmap_t vmap_; - bmap_t bmap_; tmap_t tmap_; allocation *alloc_; tune *params_; diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 213ecd6d2..bba6e8e43 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -107,7 +107,7 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) { /* convert ir::instruction to llvm::Instruction */ Instruction *selection::llvm_inst(ir::instruction *inst, std::function value, IRBuilder<> &builder) { LLVMContext & ctx = builder.getContext(); - auto block = [&](ir::basic_block *x) { return bmap_.at(x); }; + auto block = [&](ir::basic_block *x) { return (BasicBlock*)vmap_.at(x); }; auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; if(auto* ii = dynamic_cast(inst)){ BasicBlock *true_dest = block(ii->get_true_dest()); @@ -163,7 +163,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function &builder) { if(vmap_.find(v) != vmap_.end()) return vmap_.at(v); // create operands - if(auto *uu = dynamic_cast(v)) - for(ir::value* u: uu->ops()) - vmap_.insert({u, llvm_value(u, builder)}); if(auto *cc = dynamic_cast(v)) return llvm_constant(cc, ctx); // instruction @@ -445,7 +442,6 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { void selection::run(ir::module &src, Module &dst){ vmap_.clear(); - bmap_.clear(); LLVMContext &dst_ctx = dst.getContext(); IRBuilder<> dst_builder(dst_ctx); // iterate over functions @@ -459,14 +455,14 @@ void selection::run(ir::module &src, Module &dst){ // create blocks for(ir::basic_block *block: fn->blocks()) { BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn); - bmap_[block] = dst_block; + vmap_[block] = dst_block; } // create grids - dst_builder.SetInsertPoint(bmap_[fn->blocks()[0]]); + dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]); init_grids(fn, dst_builder); // iterate through block for(ir::basic_block *block: fn->blocks()) { - dst_builder.SetInsertPoint(bmap_[block]); + dst_builder.SetInsertPoint((BasicBlock*)vmap_[block]); for(ir::instruction *i: block->get_inst_list()) lower_instruction(i, dst_builder); } @@ -474,13 +470,24 @@ void selection::run(ir::module &src, Module &dst){ for(ir::basic_block *block: fn->blocks()) for(ir::instruction *inst: block->get_inst_list()) if(auto *phi = dynamic_cast(inst)){ - PHINode *dst_phi = (PHINode*)vmap_.at(phi); - for(unsigned i = 0; i < phi->get_num_incoming(); i++){ - ir::value *inc_val = phi->get_incoming_value(i); - ir::basic_block *inc_block = phi->get_incoming_block(i); - Value *llvm_inc_val = llvm_value(inc_val, dst_builder); - BasicBlock *llvm_block = bmap_[inc_block]; - dst_phi->addIncoming(llvm_inc_val, llvm_block); + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::value *inc_val = phi->get_incoming_value(n); + ir::basic_block *inc_block = phi->get_incoming_block(n); + BasicBlock *llvm_inc_block = (BasicBlock*)vmap_[inc_block]; + if(phi->get_type()->is_tile_ty()) { + distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi); + distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val); + phi_tile->for_each([&](indices_t idx){ + PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx); + Value *llvm_inc_val = inc_tile->get_value(idx); + llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); + }); + } + else { + PHINode *llvm_phi = (PHINode*)vmap_.at(phi); + Value *llvm_inc_val = vmap_.at(inc_val); + llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); + } } } }