[code generation] fixed bug in tile phi nodes

This commit is contained in:
Philippe Tillet
2019-02-06 23:34:45 -05:00
parent 53aca3fa89
commit 5fdb27d9ae
3 changed files with 26 additions and 18 deletions

View File

@@ -33,6 +33,9 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\ fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\ fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[32, 32] = c + rx[:, newaxis] + ry[newaxis, :]*M;\ fp32* pc[32, 32] = c + rx[:, newaxis] + ry[newaxis, :]*M;\
for(k = K; k >= 0; k = k - 8){\
C = C + 1;\
}\
*pc = C;\ *pc = C;\
}\ }\
"; ";

View File

@@ -72,7 +72,6 @@ private:
class selection{ class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t; typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::basic_block *, llvm::BasicBlock *> bmap_t;
typedef std::map<ir::value *, tile *> tmap_t; typedef std::map<ir::value *, tile *> tmap_t;
private: private:
@@ -100,7 +99,6 @@ public:
private: private:
vmap_t vmap_; vmap_t vmap_;
bmap_t bmap_;
tmap_t tmap_; tmap_t tmap_;
allocation *alloc_; allocation *alloc_;
tune *params_; tune *params_;

View File

@@ -107,7 +107,7 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
/* convert ir::instruction to llvm::Instruction */ /* convert ir::instruction to llvm::Instruction */
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) { Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) {
LLVMContext & ctx = builder.getContext(); LLVMContext & ctx = builder.getContext();
auto block = [&](ir::basic_block *x) { return bmap_.at(x); }; auto block = [&](ir::basic_block *x) { return (BasicBlock*)vmap_.at(x); };
auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
BasicBlock *true_dest = block(ii->get_true_dest()); BasicBlock *true_dest = block(ii->get_true_dest());
@@ -163,7 +163,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
return builder.Insert(new LoadInst(ptr)); return builder.Insert(new LoadInst(ptr));
} }
// unknown instruction // unknown instruction
throw std::runtime_error("unknown conversion from ir::type to Type"); throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
} }
/* convert ir::value to llvm::Value */ /* convert ir::value to llvm::Value */
@@ -173,9 +173,6 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
if(vmap_.find(v) != vmap_.end()) if(vmap_.find(v) != vmap_.end())
return vmap_.at(v); return vmap_.at(v);
// create operands // create operands
if(auto *uu = dynamic_cast<ir::user*>(v))
for(ir::value* u: uu->ops())
vmap_.insert({u, llvm_value(u, builder)});
if(auto *cc = dynamic_cast<ir::constant*>(v)) if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx); return llvm_constant(cc, ctx);
// instruction // instruction
@@ -445,7 +442,6 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
void selection::run(ir::module &src, Module &dst){ void selection::run(ir::module &src, Module &dst){
vmap_.clear(); vmap_.clear();
bmap_.clear();
LLVMContext &dst_ctx = dst.getContext(); LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx); IRBuilder<> dst_builder(dst_ctx);
// iterate over functions // iterate over functions
@@ -459,14 +455,14 @@ void selection::run(ir::module &src, Module &dst){
// create blocks // create blocks
for(ir::basic_block *block: fn->blocks()) { for(ir::basic_block *block: fn->blocks()) {
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn); BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
bmap_[block] = dst_block; vmap_[block] = dst_block;
} }
// create grids // create grids
dst_builder.SetInsertPoint(bmap_[fn->blocks()[0]]); dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
init_grids(fn, dst_builder); init_grids(fn, dst_builder);
// iterate through block // iterate through block
for(ir::basic_block *block: fn->blocks()) { for(ir::basic_block *block: fn->blocks()) {
dst_builder.SetInsertPoint(bmap_[block]); dst_builder.SetInsertPoint((BasicBlock*)vmap_[block]);
for(ir::instruction *i: block->get_inst_list()) for(ir::instruction *i: block->get_inst_list())
lower_instruction(i, dst_builder); lower_instruction(i, dst_builder);
} }
@@ -474,13 +470,24 @@ void selection::run(ir::module &src, Module &dst){
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list()) for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){ if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
PHINode *dst_phi = (PHINode*)vmap_.at(phi); for(unsigned n = 0; n < phi->get_num_incoming(); n++){
for(unsigned i = 0; i < phi->get_num_incoming(); i++){ ir::value *inc_val = phi->get_incoming_value(n);
ir::value *inc_val = phi->get_incoming_value(i); ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::basic_block *inc_block = phi->get_incoming_block(i); BasicBlock *llvm_inc_block = (BasicBlock*)vmap_[inc_block];
Value *llvm_inc_val = llvm_value(inc_val, dst_builder); if(phi->get_type()->is_tile_ty()) {
BasicBlock *llvm_block = bmap_[inc_block]; distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
dst_phi->addIncoming(llvm_inc_val, llvm_block); distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
phi_tile->for_each([&](indices_t idx){
PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
Value *llvm_inc_val = inc_tile->get_value(idx);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
}
} }
} }
} }