[code generation] better predication

This commit is contained in:
Philippe Tillet
2019-03-01 14:36:17 -05:00
parent 36acf22fd3
commit 08fcfbca47
6 changed files with 76 additions and 59 deletions

View File

@@ -209,7 +209,6 @@ int main() {
llvm::Module llvm_module("matmul", llvm_context); llvm::Module llvm_module("matmul", llvm_context);
triton::ir::print(module, std::cout);
// create passes // create passes
triton::codegen::buffer_info_pass buffer_info; triton::codegen::buffer_info_pass buffer_info;
@@ -264,12 +263,14 @@ int main() {
// run passes // run passes
triton::ir::print(module, std::cout);
buffer_info.run(module); buffer_info.run(module);
shared.run(module); shared.run(module);
liveness.run(module); liveness.run(module);
allocation.run(); allocation.run();
barriers.run(module); barriers.run(module);
vectorize.run(module); vectorize.run(module);
triton::ir::print(module, std::cout);
selection.run(module, llvm_module); selection.run(module, llvm_module);
// llvm source // llvm source

View File

@@ -299,6 +299,8 @@ public:
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
ir::value* codegen(ir::module *mod) const; ir::value* codegen(ir::module *mod) const;
const expression *lvalue() const { return lvalue_; }
const expression *rvalue() const { return rvalue_; }
public: public:
ASSIGN_OP_T op_; ASSIGN_OP_T op_;

View File

@@ -100,6 +100,7 @@ private:
class selection{ class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t; typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t; typedef std::map<ir::value *, tile *> tmap_t;
typedef std::map<std::pair<tile*, indices_t>, llvm::BasicBlock*> pmap_t;
private: private:
// utils // utils
@@ -131,6 +132,8 @@ public:
private: private:
vmap_t vmap_; vmap_t vmap_;
tmap_t tmap_; tmap_t tmap_;
pmap_t pmap_;
pmap_t last_block_;
allocation *alloc_; allocation *alloc_;
tune *params_; tune *params_;
buffer_info_pass *buffer_info_; buffer_info_pass *buffer_info_;

View File

@@ -346,6 +346,10 @@ public:
static merge_inst* create(ir::value *mask_true, ir::value *value_true, static merge_inst* create(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false, ir::value *mask_false, ir::value *value_false,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
ir::value *get_mask_true() { return get_operand(0); }
ir::value *get_value_true() { return get_operand(1); }
ir::value *get_mask_false() { return get_operand(2); }
ir::value *get_value_false() { return get_operand(3); }
}; };

View File

@@ -302,21 +302,25 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
/* expression statement */ /* expression statement */
ir::value* expression_statement::codegen(ir::module *mod) const{ ir::value* expression_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder(); ir::builder &builder = mod->get_builder();
ir::value *expr = expr_->codegen(mod);
if(mask_) { if(mask_) {
ir::value *pred = mask_->codegen(mod); ir::value *pred = mask_->codegen(mod);
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred); ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
ir::value *true_value = expr_->codegen(mod); ir::value *true_value = expr_->codegen(mod);
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
assert(assignment);
ir::type *ty = true_value->get_type(); ir::type *ty = true_value->get_type();
if(auto *itn = dynamic_cast<ir::instruction*>(true_value)) if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(mask->get_result(0)); itn->set_mask_pred(mask->get_result(0));
if(expr->get_type()->is_void_ty()) if(ty->is_void_ty())
return expr; return true_value;
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value, ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
mask->get_result(1), ir::undef_value::get(ty)); mask->get_result(1), ir::undef_value::get(ty));
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
mod->set_value(name, merge);
return merge; return merge;
} }
return expr; return expr_->codegen(mod);
} }
/* Iteration statement */ /* Iteration statement */

View File

@@ -517,41 +517,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
BasicBlock *block = builder.GetInsertBlock(); BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule(); Module *module = block->getModule();
LLVMContext &ctx = builder.getContext(); LLVMContext &ctx = builder.getContext();
// // helper to handle masks Function *fn = block->getParent();
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) { ir::value *mask = ins->get_mask_pred();
// BasicBlock *block = builder.GetInsertBlock(); auto set_mask_insert_pt = [&](indices_t idx){
// Value *result; if(mask){
// if(mask_pred){ distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx); BasicBlock *block = pmap_.at({mask_tile, idx});
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function); builder.SetInsertPoint(block->getTerminator());
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function); }
// builder.CreateCondBr(llvm_mask, then_bb, done_bb); };
// builder.SetInsertPoint(then_bb);
// result = insert_value();
// builder.CreateBr(done_bb);
// builder.SetInsertPoint(done_bb);
// if(!ins->get_type()->is_void_ty()){
// Type *ty = result->getType();
// PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask_else)
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
// else
// phi->addIncoming(llvm::UndefValue::get(ty), block);
// phi->addIncoming(result, then_bb);
// return (Value*)phi;
// }
// }
// else
// result = insert_value();
// return result;
// };
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
// store // store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) { if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand()); distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand()); tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){ ptr->for_each([&](indices_t idx){
set_mask_insert_pt(idx);
builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
}); });
} }
@@ -578,24 +558,46 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
} }
// mask // mask
else if(dynamic_cast<ir::mask_inst*>(ins)) { else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0); distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done"); distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
// pred->for_each([&](indices_t idx){ distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if"); pred->for_each([&](indices_t idx){
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else"); BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb); BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
// builder.SetInsertPoint(mask_if_bb); BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
// builder.CreateBr(mask_done_bb); builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
// builder.SetInsertPoint(mask_else_bb); builder.SetInsertPoint(mask_then_bb);
// builder.CreateBr(mask_done_bb); builder.CreateBr(mask_done_bb);
// }); builder.SetInsertPoint(mask_else_bb);
// builder.SetInsertPoint(mask_done_bb); builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
});
} }
// merge // merge
else if(dynamic_cast<ir::merge_inst*>(ins)) { else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
// result->for_each([&](indices_t idx){ distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
// std::cout << "merge" << std::endl; distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
// }); distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
result->for_each([&](indices_t idx){
BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
Value *value_true = value_tile_true->get_value(idx);
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
Value *value_false = value_tile_false->get_value(idx);
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
if(block_done->empty())
builder.SetInsertPoint(block_done);
else
builder.SetInsertPoint(block_done->getTerminator());
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
phi->addIncoming(value_true, block_true);
phi->addIncoming(value_false,block_false);
result->set_value(idx, phi);
});
} }
// reshape // reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) { else if(dynamic_cast<ir::reshape_inst*>(ins)) {
@@ -691,12 +693,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else else
return llvm_value(x, builder); return llvm_value(x, builder);
}; };
set_mask_insert_pt(idx);
result->set_value(idx, llvm_inst(ins, value, builder)); result->set_value(idx, llvm_inst(ins, value, builder));
}); });
} }
} }
if(mask)
builder.SetInsertPoint(block);
} }
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
@@ -722,7 +725,6 @@ void selection::run(ir::module &src, Module &dst){
vmap_.clear(); vmap_.clear();
LLVMContext &dst_ctx = dst.getContext(); LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx); IRBuilder<> dst_builder(dst_ctx);
std::map<ir::value*, llvm::BasicBlock*> block_of;
// iterate over functions // iterate over functions
for(ir::function *fn: src.get_function_list()) { for(ir::function *fn: src.get_function_list()) {
@@ -771,12 +773,13 @@ void selection::run(ir::module &src, Module &dst){
BasicBlock *parent = (BasicBlock*)vmap_[block]; BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent); dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()){ BasicBlock *current = dst_builder.GetInsertBlock();
dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt()); bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
} if(phi_inserted)
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
lower_instruction(i, dst_builder); lower_instruction(i, dst_builder);
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()) if(phi_inserted)
dst_builder.SetInsertPoint(parent); dst_builder.SetInsertPoint(current);
last_block[block] = dst_builder.GetInsertBlock(); last_block[block] = dst_builder.GetInsertBlock();
} }
} }