[code generation] better predication

This commit is contained in:
Philippe Tillet
2019-03-01 14:36:17 -05:00
parent 36acf22fd3
commit 08fcfbca47
6 changed files with 76 additions and 59 deletions

View File

@@ -209,7 +209,6 @@ int main() {
llvm::Module llvm_module("matmul", llvm_context);
triton::ir::print(module, std::cout);
// create passes
triton::codegen::buffer_info_pass buffer_info;
@@ -264,12 +263,14 @@ int main() {
// run passes
triton::ir::print(module, std::cout);
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
triton::ir::print(module, std::cout);
selection.run(module, llvm_module);
// llvm source

View File

@@ -299,6 +299,8 @@ public:
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
ir::value* codegen(ir::module *mod) const;
const expression *lvalue() const { return lvalue_; }
const expression *rvalue() const { return rvalue_; }
public:
ASSIGN_OP_T op_;

View File

@@ -100,6 +100,7 @@ private:
class selection{
typedef std::map<ir::value *, llvm::Value *> vmap_t;
typedef std::map<ir::value *, tile *> tmap_t;
typedef std::map<std::pair<tile*, indices_t>, llvm::BasicBlock*> pmap_t;
private:
// utils
@@ -131,6 +132,8 @@ public:
private:
vmap_t vmap_;
tmap_t tmap_;
pmap_t pmap_;
pmap_t last_block_;
allocation *alloc_;
tune *params_;
buffer_info_pass *buffer_info_;

View File

@@ -346,6 +346,10 @@ public:
static merge_inst* create(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false,
const std::string &name = "", instruction *next = nullptr);
ir::value *get_mask_true() { return get_operand(0); }
ir::value *get_value_true() { return get_operand(1); }
ir::value *get_mask_false() { return get_operand(2); }
ir::value *get_value_false() { return get_operand(3); }
};

View File

@@ -302,21 +302,25 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
/* expression statement */
ir::value* expression_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::value *expr = expr_->codegen(mod);
if(mask_) {
ir::value *pred = mask_->codegen(mod);
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
ir::value *true_value = expr_->codegen(mod);
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
assert(assignment);
ir::type *ty = true_value->get_type();
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(mask->get_result(0));
if(expr->get_type()->is_void_ty())
return expr;
if(ty->is_void_ty())
return true_value;
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
mask->get_result(1), ir::undef_value::get(ty));
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
mod->set_value(name, merge);
return merge;
}
return expr;
return expr_->codegen(mod);
}
/* Iteration statement */

View File

@@ -517,41 +517,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
LLVMContext &ctx = builder.getContext();
// // helper to handle masks
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
// BasicBlock *block = builder.GetInsertBlock();
// Value *result;
// if(mask_pred){
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
// builder.SetInsertPoint(then_bb);
// result = insert_value();
// builder.CreateBr(done_bb);
// builder.SetInsertPoint(done_bb);
// if(!ins->get_type()->is_void_ty()){
// Type *ty = result->getType();
// PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask_else)
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
// else
// phi->addIncoming(llvm::UndefValue::get(ty), block);
// phi->addIncoming(result, then_bb);
// return (Value*)phi;
// }
// }
// else
// result = insert_value();
// return result;
// };
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
Function *fn = block->getParent();
ir::value *mask = ins->get_mask_pred();
auto set_mask_insert_pt = [&](indices_t idx){
if(mask){
distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
BasicBlock *block = pmap_.at({mask_tile, idx});
builder.SetInsertPoint(block->getTerminator());
}
};
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
set_mask_insert_pt(idx);
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
});
}
@@ -578,24 +558,46 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
}
// mask
else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
// pred->for_each([&](indices_t idx){
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
// builder.SetInsertPoint(mask_if_bb);
// builder.CreateBr(mask_done_bb);
// builder.SetInsertPoint(mask_else_bb);
// builder.CreateBr(mask_done_bb);
// });
// builder.SetInsertPoint(mask_done_bb);
distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
pred->for_each([&](indices_t idx){
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_else_bb);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
});
}
// merge
else if(dynamic_cast<ir::merge_inst*>(ins)) {
// result->for_each([&](indices_t idx){
// std::cout << "merge" << std::endl;
// });
else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
result->for_each([&](indices_t idx){
BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
Value *value_true = value_tile_true->get_value(idx);
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
Value *value_false = value_tile_false->get_value(idx);
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
if(block_done->empty())
builder.SetInsertPoint(block_done);
else
builder.SetInsertPoint(block_done->getTerminator());
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
phi->addIncoming(value_true, block_true);
phi->addIncoming(value_false,block_false);
result->set_value(idx, phi);
});
}
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
@@ -691,12 +693,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else
return llvm_value(x, builder);
};
set_mask_insert_pt(idx);
result->set_value(idx, llvm_inst(ins, value, builder));
});
}
}
if(mask)
builder.SetInsertPoint(block);
}
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
@@ -722,7 +725,6 @@ void selection::run(ir::module &src, Module &dst){
vmap_.clear();
LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx);
std::map<ir::value*, llvm::BasicBlock*> block_of;
// iterate over functions
for(ir::function *fn: src.get_function_list()) {
@@ -771,12 +773,13 @@ void selection::run(ir::module &src, Module &dst){
BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()){
dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt());
}
BasicBlock *current = dst_builder.GetInsertBlock();
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
if(phi_inserted)
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
lower_instruction(i, dst_builder);
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty())
dst_builder.SetInsertPoint(parent);
if(phi_inserted)
dst_builder.SetInsertPoint(current);
last_block[block] = dst_builder.GetInsertBlock();
}
}