From 08fcfbca47b80b733ac721d2ce6050b7f1f76a69 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 1 Mar 2019 14:36:17 -0500 Subject: [PATCH] [code generation] better predication --- examples/matrix.cpp | 3 +- include/triton/ast/ast.h | 2 + include/triton/codegen/selection.h | 3 + include/triton/ir/instructions.h | 4 ++ lib/ast/lowering.cpp | 12 ++-- lib/codegen/selection.cpp | 111 +++++++++++++++-------------- 6 files changed, 76 insertions(+), 59 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index e023dc2ee..752f60bfe 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -209,7 +209,6 @@ int main() { llvm::Module llvm_module("matmul", llvm_context); - triton::ir::print(module, std::cout); // create passes triton::codegen::buffer_info_pass buffer_info; @@ -264,12 +263,14 @@ int main() { // run passes + triton::ir::print(module, std::cout); buffer_info.run(module); shared.run(module); liveness.run(module); allocation.run(); barriers.run(module); vectorize.run(module); + triton::ir::print(module, std::cout); selection.run(module, llvm_module); // llvm source diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index ab96bfa36..b9ae16ea0 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -299,6 +299,8 @@ public: : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } ir::value* codegen(ir::module *mod) const; + const expression *lvalue() const { return lvalue_; } + const expression *rvalue() const { return rvalue_; } public: ASSIGN_OP_T op_; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index b0ebc8c5e..c8632262e 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -100,6 +100,7 @@ private: class selection{ typedef std::map vmap_t; typedef std::map tmap_t; + typedef std::map, llvm::BasicBlock*> pmap_t; private: // utils @@ -131,6 +132,8 @@ public: private: vmap_t vmap_; tmap_t tmap_; + pmap_t pmap_; + pmap_t last_block_; allocation *alloc_; tune *params_; buffer_info_pass *buffer_info_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 2328a4a8f..2d8e7d91d 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -346,6 +346,10 @@ public: static merge_inst* create(ir::value *mask_true, ir::value *value_true, ir::value *mask_false, ir::value *value_false, const std::string &name = "", instruction *next = nullptr); + ir::value *get_mask_true() { return get_operand(0); } + ir::value *get_value_true() { return get_operand(1); } + ir::value *get_mask_false() { return get_operand(2); } + ir::value *get_value_false() { return get_operand(3); } }; diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 34cc36e2f..c9ad27a01 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -302,21 +302,25 @@ ir::value* compound_statement::codegen(ir::module* mod) const{ /* expression statement */ ir::value* expression_statement::codegen(ir::module *mod) const{ ir::builder &builder = mod->get_builder(); - ir::value *expr = expr_->codegen(mod); if(mask_) { ir::value *pred = mask_->codegen(mod); ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred); ir::value *true_value = expr_->codegen(mod); + assignment_expression *assignment = dynamic_cast(expr_); + assert(assignment); + ir::type *ty = true_value->get_type(); if(auto *itn = dynamic_cast(true_value)) itn->set_mask_pred(mask->get_result(0)); - if(expr->get_type()->is_void_ty()) - return expr; + if(ty->is_void_ty()) + return true_value; ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value, mask->get_result(1), ir::undef_value::get(ty)); + std::string name = ((named_expression*)assignment->lvalue())->id()->name(); + mod->set_value(name, merge); return merge; } - return expr; + return expr_->codegen(mod); } /* Iteration statement */ diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 1c8269257..52f8011c9 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -517,41 +517,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & BasicBlock *block = builder.GetInsertBlock(); Module *module = block->getModule(); LLVMContext &ctx = builder.getContext(); -// // helper to handle masks -// auto insert_masked = [&](indices_t idx, std::function insert_value) { -// BasicBlock *block = builder.GetInsertBlock(); -// Value *result; -// if(mask_pred){ -// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx); -// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function); -// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function); -// builder.CreateCondBr(llvm_mask, then_bb, done_bb); -// builder.SetInsertPoint(then_bb); -// result = insert_value(); -// builder.CreateBr(done_bb); -// builder.SetInsertPoint(done_bb); -// if(!ins->get_type()->is_void_ty()){ -// Type *ty = result->getType(); -// PHINode *phi = builder.CreatePHI(ty, 2); -// if(mask_else) -// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block); -// else -// phi->addIncoming(llvm::UndefValue::get(ty), block); -// phi->addIncoming(result, then_bb); -// return (Value*)phi; -// } -// } -// else -// result = insert_value(); -// return result; -// }; - - std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl; + Function *fn = block->getParent(); + ir::value *mask = ins->get_mask_pred(); + auto set_mask_insert_pt = [&](indices_t idx){ + if(mask){ + distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred()); + BasicBlock *block = pmap_.at({mask_tile, idx}); + builder.SetInsertPoint(block->getTerminator()); + } + }; // store if(auto *x = dynamic_cast(ins)) { distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand()); tile *value = tmap_.at(x->get_value_operand()); ptr->for_each([&](indices_t idx){ + set_mask_insert_pt(idx); builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); }); } @@ -578,24 +558,46 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & } // mask else if(dynamic_cast(ins)) { -// distributed_tile* pred = (distributed_tile*)ins->get_operand(0); -// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done"); -// pred->for_each([&](indices_t idx){ -// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if"); -// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else"); -// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb); -// builder.SetInsertPoint(mask_if_bb); -// builder.CreateBr(mask_done_bb); -// builder.SetInsertPoint(mask_else_bb); -// builder.CreateBr(mask_done_bb); -// }); -// builder.SetInsertPoint(mask_done_bb); + distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0)); + distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0)); + distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1)); + pred->for_each([&](indices_t idx){ + BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn); + BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn); + BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn); + builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb); + builder.SetInsertPoint(mask_then_bb); + builder.CreateBr(mask_done_bb); + builder.SetInsertPoint(mask_else_bb); + builder.CreateBr(mask_done_bb); + builder.SetInsertPoint(mask_done_bb); + pmap_.insert({{mask_tile_true, idx}, mask_then_bb}); + pmap_.insert({{mask_tile_false, idx}, mask_else_bb}); + last_block_.insert({{mask_tile_true, idx}, mask_done_bb}); + last_block_.insert({{mask_tile_false, idx}, mask_done_bb}); + }); } // merge - else if(dynamic_cast(ins)) { -// result->for_each([&](indices_t idx){ -// std::cout << "merge" << std::endl; -// }); + else if(auto *merge = dynamic_cast(ins)) { + distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true()); + distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true()); + distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false()); + distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false()); + result->for_each([&](indices_t idx){ + BasicBlock *block_true = pmap_.at({mask_tile_true, idx}); + Value *value_true = value_tile_true->get_value(idx); + BasicBlock *block_false = pmap_.at({mask_tile_false, idx}); + Value *value_false = value_tile_false->get_value(idx); + BasicBlock *block_done = last_block_.at({mask_tile_true, idx}); + if(block_done->empty()) + builder.SetInsertPoint(block_done); + else + builder.SetInsertPoint(block_done->getTerminator()); + PHINode *phi = builder.CreatePHI(value_true->getType(), 2); + phi->addIncoming(value_true, block_true); + phi->addIncoming(value_false,block_false); + result->set_value(idx, phi); + }); } // reshape else if(dynamic_cast(ins)) { @@ -691,12 +693,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & else return llvm_value(x, builder); }; + set_mask_insert_pt(idx); result->set_value(idx, llvm_inst(ins, value, builder)); }); } } - - + if(mask) + builder.SetInsertPoint(block); } void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { @@ -722,7 +725,6 @@ void selection::run(ir::module &src, Module &dst){ vmap_.clear(); LLVMContext &dst_ctx = dst.getContext(); IRBuilder<> dst_builder(dst_ctx); - std::map block_of; // iterate over functions for(ir::function *fn: src.get_function_list()) { @@ -771,12 +773,13 @@ void selection::run(ir::module &src, Module &dst){ BasicBlock *parent = (BasicBlock*)vmap_[block]; dst_builder.SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - if(dynamic_cast(i) && !parent->empty()){ - dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt()); - } + BasicBlock *current = dst_builder.GetInsertBlock(); + bool phi_inserted = (dynamic_cast(i) || dynamic_cast(i)) && !current->empty(); + if(phi_inserted) + dst_builder.SetInsertPoint(&*current->getFirstInsertionPt()); lower_instruction(i, dst_builder); - if(dynamic_cast(i) && !parent->empty()) - dst_builder.SetInsertPoint(parent); + if(phi_inserted) + dst_builder.SetInsertPoint(current); last_block[block] = dst_builder.GetInsertBlock(); } }