From 5a51f3e5294f6604fdccf4cddb5d2c676e71582a Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 14 Jun 2021 12:22:48 -0400
Subject: [PATCH] [CODEGEN] Bugfix in membar pass (#124)

The membar pass on top of master is buggy with asynchronous copies. For
example, it doesn't wait for asynchronous copies to complete before
recoalescing the accumulator in GEMM, which leads to undefined behavior
when the program doesn't enter the loop. This PR proposes
---
 include/triton/codegen/analysis/layout.h  |  3 +-
 include/triton/codegen/transform/membar.h |  2 ++
 include/triton/ir/utils.h                 |  1 +
 lib/codegen/pass.cc                       |  2 ++
 lib/codegen/transform/membar.cc           | 34 +++++++++++++++--------
 lib/ir/utils.cc                           | 28 ++++++++++++++-----
 6 files changed, 51 insertions(+), 19 deletions(-)

diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index d11372a2e..14c760bf1 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -203,7 +203,8 @@ public:
   data_layout* get(size_t id)    { return layouts_.at(id); }
   data_layout* get(ir::value *v) { return get(layout_of(v));}
   std::map<size_t, data_layout*> &get_all() { return layouts_; }
-  size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
+  bool has_tmp(ir::value* i)     { return tmp_.find(i) != tmp_.end(); }
+  int tmp(ir::value* i)          { return tmp_.at(i);}
   // execution
   void run(ir::module &mod);
diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h
index b5a11a46b..d35bd10ba 100644
--- a/include/triton/codegen/transform/membar.h
+++ b/include/triton/codegen/transform/membar.h
@@ -26,6 +26,7 @@ class allocation;
 class liveness;
 class layouts;
 class cts;
+class shared_layout;
 }
@@ -40,6 +41,7 @@ private:
 private:
   bool intersect(const val_set_t &X, const val_set_t &Y);
   int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
+  bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
   val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
   void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
                 std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
diff --git a/include/triton/ir/utils.h b/include/triton/ir/utils.h
index 3b9e2f5f3..893edd122 100644
--- a/include/triton/ir/utils.h
+++ b/include/triton/ir/utils.h
@@ -17,6 +17,7 @@ class value;
 
 class cfg {
 public:
+  static std::vector<basic_block *> post_order(function* fn);
   static std::vector<basic_block *> reverse_post_order(function* fn);
 };
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index a39d927dc..d0d2f34ec 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -92,7 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   liveness.run(ir);
   allocation.run(ir);
   prefetch_s.run(ir);
+//  ir::print(ir, std::cout);
   barriers.run(ir);
+//  ir::print(ir, std::cout);
 //  ir::print(ir, std::cout);
   isel.visit(ir, *llvm);
   mod = driver::module::create(dev, std::move(llvm));
diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc
index 95bb044b8..de2552c32 100644
--- a/lib/codegen/transform/membar.cc
+++ b/lib/codegen/transform/membar.cc
@@ -28,11 +28,24 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
     return *std::max_element(groups.begin(), groups.end());
   }
   else{
+    if(layouts_->has_tmp(v))
+      return async_write.size() - 1;
     auto it = std::find(async_write.begin(), async_write.end(), v);
     return std::distance(async_write.begin(), it);
   }
 }
 
+inline bool
+membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
+  if(!a_layout || !b_layout)
+    return false;
+  int a_start = alloc_->offset(a_layout);
+  int a_end = a_start + a_layout->get_size();
+  int b_start = alloc_->offset(b_layout);
+  int b_end = b_start + b_layout->get_size();
+  if(a_start < b_end || b_start < a_end)
+    return true;
+  return false;
+}
 
 membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
   val_set_t ret;
@@ -40,19 +53,16 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
     if(!a->get_type()->is_block_ty())
       continue;
     analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
-    if(!a_layout)
-      continue;
-    int a_start = alloc_->offset(a_layout);
-    int a_end = a_start + a_layout->get_size();
+    analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
     for(ir::value* b: bs){
       if(!b->get_type()->is_block_ty())
         continue;
       analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
-      if(!b_layout)
-        continue;
-      int b_start = alloc_->offset(b_layout);
-      int b_end = b_start + b_layout->get_size();
-      if(a_start < b_end || b_start < a_end)
+      analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
+      if(intersect_with(a_layout, b_layout) ||
+         intersect_with(a_layout, b_tmp) ||
+         intersect_with(a_tmp, b_layout) ||
+         intersect_with(a_tmp, b_tmp))
         ret.insert(b);
     }
   }
@@ -81,6 +91,8 @@ void membar::transfer(ir::basic_block *block,
     std::set<ir::value*> read;
     std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
                  [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
+    if(layouts_->has_tmp(i))
+      read.insert(i);
     // RAW (async)
     val_set_t tmp;
     std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
@@ -101,7 +113,7 @@ void membar::transfer(ir::basic_block *block,
                                  layouts_->get(i)->to_shared()->get_double_buffer();
     // WAR barrier is not required when data is double-buffered
     // TODO: how about other patterns, like WWAR?
-    if(!intersect_with(read, sync_write).empty() ||
+    if(!intersect_with(read, sync_write).empty() ||
        (!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
        // force WAR barrier on A100
        (!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
@@ -175,4 +187,4 @@ void membar::run(ir::module &mod) {
   }
 }
 
-}
\ No newline at end of file
+}
diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc
index 7baf5df14..cbfb4baf9 100644
--- a/lib/ir/utils.cc
+++ b/lib/ir/utils.cc
@@ -8,25 +8,39 @@ namespace triton{
 namespace ir{
 
-std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
+std::vector<basic_block*> cfg::post_order(function* fn) {
   std::stack<basic_block*> stack;
   std::set<basic_block*> visited;
   std::vector<basic_block*> result;
   // initialize stack
   for(ir::basic_block* block: fn->blocks())
-    if(block->get_predecessors().empty())
+    if(block->get_predecessors().empty()){
       stack.push(block);
+      visited.insert(block);
+    }
   // DFS
   while(!stack.empty()) {
     basic_block* current = stack.top();
-    stack.pop();
-    result.push_back(current);
-    visited.insert(current);
+    bool tail = true;
     for(basic_block* succ: current->get_successors())
-      if(visited.find(succ) == visited.end())
+      if(visited.find(succ) == visited.end()){
         stack.push(succ);
+        visited.insert(succ);
+        tail = false;
+        break;
+      }
+    if(tail){
+      stack.pop();
+      result.push_back(current);
+    }
   }
-  return std::move(result);
+  return result;
+}
+
+std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
+  auto result = post_order(fn);
+  std::reverse(result.begin(), result.end());
+  return result;
 }
 
 void for_each_instruction(module &mod, const std::function<void(instruction*)> &do_work) {
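
For context, the heart of the membar change is the new `membar::intersect_with(shared_layout*, shared_layout*)` helper: instead of skipping an operand that has no shared layout of its own, the pass now also considers the temporary ("tmp") shared-memory layout attached to an instruction via `layouts_->has_tmp` / `layouts_->tmp`, so the read/write conflict sets also cover the scratch buffer used when re-coalescing the accumulator. The snippet below is a minimal, self-contained sketch of the underlying byte-range test only; it is not part of the patch, and `Alloc` and `ranges_overlap` are hypothetical stand-ins for a `shared_layout` together with `allocation::offset()` and `shared_layout::get_size()`.

// Sketch only (not part of the patch): the byte-range test behind the new
// membar::intersect_with(shared_layout*, shared_layout*). "Alloc" is a
// hypothetical stand-in for a shared_layout plus its allocation offset.
#include <iostream>

struct Alloc {
  int offset;  // byte offset assigned by the shared-memory allocator
  int size;    // size of the region in bytes
};

// Strict overlap of the half-open ranges [a.offset, a.offset + a.size)
// and [b.offset, b.offset + b.size).
bool ranges_overlap(const Alloc& a, const Alloc& b) {
  return a.offset < b.offset + b.size && b.offset < a.offset + a.size;
}

int main() {
  Alloc a{0, 4096}, b{4096, 2048}, c{2048, 4096};
  std::cout << ranges_overlap(a, b) << "\n";  // 0: disjoint regions
  std::cout << ranges_overlap(a, c) << "\n";  // 1: [0,4096) and [2048,6144) overlap
}

Note that the condition actually used in the patch, `a_start < b_end || b_start < a_end`, is true for any two non-empty regions, so it is strictly more conservative than the `&&` test above: when in doubt, the pass inserts a barrier rather than risk a missed synchronization.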
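The `lib/ir/utils.cc` part of the patch replaces the old traversal, which pushed every unvisited successor and recorded blocks as soon as they were popped, with a genuine iterative post-order (a block is emitted only after all of its successors), and `reverse_post_order` is now just the reversal of that list. Below is a self-contained sketch of the same traversal on a toy graph, not part of the patch; the `Node` struct and the diamond-shaped CFG in `main` are hypothetical stand-ins for `ir::basic_block` and `function::blocks()`.

// Sketch only (not part of the patch): iterative post-order DFS over a toy
// CFG, mirroring the structure of the new cfg::post_order().
#include <algorithm>
#include <iostream>
#include <set>
#include <stack>
#include <vector>

struct Node {
  int id;
  std::vector<Node*> succs;
  std::vector<Node*> preds;
};

std::vector<Node*> post_order(const std::vector<Node*>& blocks) {
  std::stack<Node*> stack;
  std::set<Node*> visited;
  std::vector<Node*> result;
  // start from entry nodes (no predecessors), marking them visited on push
  for (Node* b : blocks)
    if (b->preds.empty()) {
      stack.push(b);
      visited.insert(b);
    }
  // a node is emitted only once all of its successors have been emitted
  while (!stack.empty()) {
    Node* current = stack.top();
    bool tail = true;
    for (Node* succ : current->succs)
      if (visited.find(succ) == visited.end()) {
        stack.push(succ);
        visited.insert(succ);
        tail = false;
        break;  // descend into one unvisited successor at a time
      }
    if (tail) {  // no unvisited successor left: emit the node
      stack.pop();
      result.push_back(current);
    }
  }
  return result;
}

int main() {
  // diamond CFG: 0 -> {1, 2} -> 3
  Node n0{0, {}, {}}, n1{1, {}, {}}, n2{2, {}, {}}, n3{3, {}, {}};
  n0.succs = {&n1, &n2};
  n1.preds = {&n0}; n1.succs = {&n3};
  n2.preds = {&n0}; n2.succs = {&n3};
  n3.preds = {&n1, &n2};
  std::vector<Node*> blocks{&n0, &n1, &n2, &n3};

  auto po = post_order(blocks);        // prints 3 1 2 0 after reversal below
  std::reverse(po.begin(), po.end());  // reverse post-order: 0 2 1 3
  for (Node* n : po) std::cout << n->id << " ";
  std::cout << "\n";
}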