diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index d11372a2e..14c760bf1 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -203,7 +203,8 @@ public:
   data_layout* get(size_t id) { return layouts_.at(id); }
   data_layout* get(ir::value *v) { return get(layout_of(v));}
   std::map &get_all() { return layouts_; }
-  size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
+  bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
+  int tmp(ir::value* i) { return tmp_.at(i);}
 
   // execution
   void run(ir::module &mod);
diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h
index b5a11a46b..d35bd10ba 100644
--- a/include/triton/codegen/transform/membar.h
+++ b/include/triton/codegen/transform/membar.h
@@ -26,6 +26,7 @@ class allocation;
 class liveness;
 class layouts;
 class cts;
+class shared_layout;
 
 }
 
@@ -40,6 +41,7 @@ private:
 private:
   bool intersect(const val_set_t &X, const val_set_t &Y);
   int group_of(triton::ir::value *i, std::vector &async_write);
+  bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
   val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
   void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
                 std::set &safe_war, bool &inserted, ir::builder &builder);
diff --git a/include/triton/ir/utils.h b/include/triton/ir/utils.h
index 3b9e2f5f3..893edd122 100644
--- a/include/triton/ir/utils.h
+++ b/include/triton/ir/utils.h
@@ -17,6 +17,7 @@ class value;
 
 class cfg {
 public:
+  static std::vector post_order(function* fn);
   static std::vector reverse_post_order(function* fn);
 };
 
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index a39d927dc..d0d2f34ec 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -92,7 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   liveness.run(ir);
   allocation.run(ir);
   prefetch_s.run(ir);
+// ir::print(ir, std::cout);
   barriers.run(ir);
+// ir::print(ir, std::cout);
   // ir::print(ir, std::cout);
   isel.visit(ir, *llvm);
   mod = driver::module::create(dev, std::move(llvm));
diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc
index 95bb044b8..de2552c32 100644
--- a/lib/codegen/transform/membar.cc
+++ b/lib/codegen/transform/membar.cc
@@ -28,11 +28,24 @@ int membar::group_of(ir::value* v, std::vector &async_write) {
     return *std::max_element(groups.begin(), groups.end());
   }
   else{
+    if(layouts_->has_tmp(v))
+      return async_write.size() - 1;
     auto it = std::find(async_write.begin(), async_write.end(), v);
     return std::distance(async_write.begin(), it);
   }
 }
 
+// Two shared layouts intersect iff both are non-null and their [offset, offset + size) ranges overlap.
+inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
+  if(!a_layout || !b_layout)
+    return false;
+  int a_start = alloc_->offset(a_layout);
+  int a_end = a_start + a_layout->get_size();
+  int b_start = alloc_->offset(b_layout);
+  int b_end = b_start + b_layout->get_size();
+  // both bounds must hold; with `||` every pair of non-empty ranges would match
+  return a_start < b_end && b_start < a_end;
+}
 
 membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
   val_set_t ret;
@@ -40,19 +53,16 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
     if(!a->get_type()->is_block_ty())
       continue;
     analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
-    if(!a_layout)
-      continue;
-    int a_start = alloc_->offset(a_layout);
-    int a_end = a_start + a_layout->get_size();
+    analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
     for(ir::value* b: bs){
       if(!b->get_type()->is_block_ty())
         continue;
       analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
-      if(!b_layout)
-        continue;
-      int b_start = alloc_->offset(b_layout);
-      int b_end = b_start + b_layout->get_size();
-      if(a_start < b_end || b_start < a_end)
+      analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
+      if(intersect_with(a_layout, b_layout) ||
+         intersect_with(a_layout, b_tmp) ||
+         intersect_with(a_tmp, b_layout) ||
+         intersect_with(a_tmp, b_tmp))
         ret.insert(b);
     }
   }
@@ -81,6 +91,8 @@ void membar::transfer(ir::basic_block *block,
       std::set read;
       std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
                    [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
+      if(layouts_->has_tmp(i))
+        read.insert(i);
       // RAW (async)
       val_set_t tmp;
       std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
@@ -101,7 +113,7 @@ void membar::transfer(ir::basic_block *block,
                                    layouts_->get(i)->to_shared()->get_double_buffer();
       // WAR barrier is not required when data is double-buffered
       // TODO: how about other patterns, like WWAR?
-      if(!intersect_with(read, sync_write).empty() ||
+      if(!intersect_with(read, sync_write).empty() ||
          (!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
          // force WAR barrier on A100
          (!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
@@ -175,4 +187,4 @@ void membar::run(ir::module &mod) {
   }
 }
 
-}
\ No newline at end of file
+}
diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc
index 7baf5df14..cbfb4baf9 100644
--- a/lib/ir/utils.cc
+++ b/lib/ir/utils.cc
@@ -8,25 +8,41 @@
 namespace triton{
 namespace ir{
 
-std::vector cfg::reverse_post_order(function* fn) {
+// Iterative DFS from every block without predecessors; a block is emitted once it has no unvisited successor left.
+std::vector cfg::post_order(function* fn) {
   std::stack stack;
   std::set visited;
   std::vector result;
   // initialize stack
   for(ir::basic_block* block: fn->blocks())
-    if(block->get_predecessors().empty())
+    if(block->get_predecessors().empty()){
       stack.push(block);
+      visited.insert(block);
+    }
   // DFS
   while(!stack.empty()) {
     basic_block* current = stack.top();
-    stack.pop();
-    result.push_back(current);
-    visited.insert(current);
+    bool tail = true;
     for(basic_block* succ: current->get_successors())
-      if(visited.find(succ) == visited.end())
+      if(visited.find(succ) == visited.end()){
         stack.push(succ);
+        visited.insert(succ);
+        tail = false;
+        break;
+      }
+    if(tail){
+      stack.pop();
+      result.push_back(current);
+    }
   }
-  return std::move(result);
+  return result;
+}
+
+// The sequence produced by post_order(fn), reversed.
+std::vector cfg::reverse_post_order(function* fn) {
+  auto result = post_order(fn);
+  std::reverse(result.begin(), result.end());
+  return result;
 }
 
 void for_each_instruction(module &mod, const std::function &do_work) {