[code generation] better predication
This commit is contained in:
@@ -209,7 +209,6 @@ int main() {
|
||||
llvm::Module llvm_module("matmul", llvm_context);
|
||||
|
||||
|
||||
triton::ir::print(module, std::cout);
|
||||
|
||||
// create passes
|
||||
triton::codegen::buffer_info_pass buffer_info;
|
||||
@@ -264,12 +263,14 @@ int main() {
|
||||
|
||||
|
||||
// run passes
|
||||
triton::ir::print(module, std::cout);
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
|
@@ -299,6 +299,8 @@ public:
|
||||
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
|
||||
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
const expression *lvalue() const { return lvalue_; }
|
||||
const expression *rvalue() const { return rvalue_; }
|
||||
|
||||
public:
|
||||
ASSIGN_OP_T op_;
|
||||
|
@@ -100,6 +100,7 @@ private:
|
||||
class selection{
|
||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
typedef std::map<std::pair<tile*, indices_t>, llvm::BasicBlock*> pmap_t;
|
||||
|
||||
private:
|
||||
// utils
|
||||
@@ -131,6 +132,8 @@ public:
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
pmap_t pmap_;
|
||||
pmap_t last_block_;
|
||||
allocation *alloc_;
|
||||
tune *params_;
|
||||
buffer_info_pass *buffer_info_;
|
||||
|
@@ -346,6 +346,10 @@ public:
|
||||
static merge_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
ir::value *mask_false, ir::value *value_false,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
ir::value *get_mask_true() { return get_operand(0); }
|
||||
ir::value *get_value_true() { return get_operand(1); }
|
||||
ir::value *get_mask_false() { return get_operand(2); }
|
||||
ir::value *get_value_false() { return get_operand(3); }
|
||||
|
||||
};
|
||||
|
||||
|
@@ -302,21 +302,25 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
/* expression statement */
|
||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
if(mask_) {
|
||||
ir::value *pred = mask_->codegen(mod);
|
||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||
ir::value *true_value = expr_->codegen(mod);
|
||||
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
|
||||
assert(assignment);
|
||||
|
||||
ir::type *ty = true_value->get_type();
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
||||
itn->set_mask_pred(mask->get_result(0));
|
||||
if(expr->get_type()->is_void_ty())
|
||||
return expr;
|
||||
if(ty->is_void_ty())
|
||||
return true_value;
|
||||
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
|
||||
mask->get_result(1), ir::undef_value::get(ty));
|
||||
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
||||
mod->set_value(name, merge);
|
||||
return merge;
|
||||
}
|
||||
return expr;
|
||||
return expr_->codegen(mod);
|
||||
}
|
||||
|
||||
/* Iteration statement */
|
||||
|
@@ -517,41 +517,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Module *module = block->getModule();
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
// // helper to handle masks
|
||||
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||
// BasicBlock *block = builder.GetInsertBlock();
|
||||
// Value *result;
|
||||
// if(mask_pred){
|
||||
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
|
||||
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||
// builder.SetInsertPoint(then_bb);
|
||||
// result = insert_value();
|
||||
// builder.CreateBr(done_bb);
|
||||
// builder.SetInsertPoint(done_bb);
|
||||
// if(!ins->get_type()->is_void_ty()){
|
||||
// Type *ty = result->getType();
|
||||
// PHINode *phi = builder.CreatePHI(ty, 2);
|
||||
// if(mask_else)
|
||||
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
|
||||
// else
|
||||
// phi->addIncoming(llvm::UndefValue::get(ty), block);
|
||||
// phi->addIncoming(result, then_bb);
|
||||
// return (Value*)phi;
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// result = insert_value();
|
||||
// return result;
|
||||
// };
|
||||
|
||||
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
|
||||
Function *fn = block->getParent();
|
||||
ir::value *mask = ins->get_mask_pred();
|
||||
auto set_mask_insert_pt = [&](indices_t idx){
|
||||
if(mask){
|
||||
distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
|
||||
BasicBlock *block = pmap_.at({mask_tile, idx});
|
||||
builder.SetInsertPoint(block->getTerminator());
|
||||
}
|
||||
};
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
@@ -578,24 +558,46 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
// mask
|
||||
else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_if_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// });
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
pred->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_else_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
});
|
||||
}
|
||||
// merge
|
||||
else if(dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
// result->for_each([&](indices_t idx){
|
||||
// std::cout << "merge" << std::endl;
|
||||
// });
|
||||
else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
result->for_each([&](indices_t idx){
|
||||
BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
Value *value_true = value_tile_true->get_value(idx);
|
||||
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
Value *value_false = value_tile_false->get_value(idx);
|
||||
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
if(block_done->empty())
|
||||
builder.SetInsertPoint(block_done);
|
||||
else
|
||||
builder.SetInsertPoint(block_done->getTerminator());
|
||||
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
phi->addIncoming(value_true, block_true);
|
||||
phi->addIncoming(value_false,block_false);
|
||||
result->set_value(idx, phi);
|
||||
});
|
||||
}
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
@@ -691,12 +693,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
};
|
||||
set_mask_insert_pt(idx);
|
||||
result->set_value(idx, llvm_inst(ins, value, builder));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(mask)
|
||||
builder.SetInsertPoint(block);
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
@@ -722,7 +725,6 @@ void selection::run(ir::module &src, Module &dst){
|
||||
vmap_.clear();
|
||||
LLVMContext &dst_ctx = dst.getContext();
|
||||
IRBuilder<> dst_builder(dst_ctx);
|
||||
std::map<ir::value*, llvm::BasicBlock*> block_of;
|
||||
|
||||
// iterate over functions
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
@@ -771,12 +773,13 @@ void selection::run(ir::module &src, Module &dst){
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()){
|
||||
dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt());
|
||||
}
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
|
||||
if(phi_inserted)
|
||||
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
|
||||
lower_instruction(i, dst_builder);
|
||||
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty())
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
if(phi_inserted)
|
||||
dst_builder.SetInsertPoint(current);
|
||||
last_block[block] = dst_builder.GetInsertBlock();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user