[CODEGEN] Bugfix in membar pass (#124)
Membar pass on top of master is buggy with asynchronous copy. For example, it doesn't wait for asynchronous copies to complete before recoalescing accumulator in GEMM, which leads to undefined behavior when the program doesn't enter the loop. This PR proposes
This commit is contained in:
committed by
Philippe Tillet
parent
b7b05a560e
commit
5a51f3e529
@@ -203,7 +203,8 @@ public:
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
data_layout* get(ir::value *v) { return get(layout_of(v));}
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
@@ -26,6 +26,7 @@ class allocation;
|
||||
class liveness;
|
||||
class layouts;
|
||||
class cts;
|
||||
class shared_layout;
|
||||
|
||||
}
|
||||
|
||||
@@ -40,6 +41,7 @@ private:
|
||||
private:
|
||||
bool intersect(const val_set_t &X, const val_set_t &Y);
|
||||
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
|
||||
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
|
||||
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
|
||||
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
@@ -17,6 +17,7 @@ class value;
|
||||
|
||||
class cfg {
|
||||
public:
|
||||
static std::vector<basic_block *> post_order(function* fn);
|
||||
static std::vector<basic_block *> reverse_post_order(function* fn);
|
||||
};
|
||||
|
||||
|
@@ -92,7 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
prefetch_s.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
// ir::print(ir, std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
mod = driver::module::create(dev, std::move(llvm));
|
||||
|
@@ -28,11 +28,24 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
|
||||
return *std::max_element(groups.begin(), groups.end());
|
||||
}
|
||||
else{
|
||||
if(layouts_->has_tmp(v))
|
||||
return async_write.size() - 1;
|
||||
auto it = std::find(async_write.begin(), async_write.end(), v);
|
||||
return std::distance(async_write.begin(), it);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
|
||||
if(!a_layout || !b_layout)
|
||||
return false;
|
||||
int a_start = alloc_->offset(a_layout);
|
||||
int a_end = a_start + a_layout->get_size();
|
||||
int b_start = alloc_->offset(b_layout);
|
||||
int b_end = b_start + b_layout->get_size();
|
||||
if(a_start < b_end || b_start < a_end)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
|
||||
val_set_t ret;
|
||||
@@ -40,19 +53,16 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
|
||||
if(!a->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
|
||||
if(!a_layout)
|
||||
continue;
|
||||
int a_start = alloc_->offset(a_layout);
|
||||
int a_end = a_start + a_layout->get_size();
|
||||
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
|
||||
for(ir::value* b: bs){
|
||||
if(!b->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
|
||||
if(!b_layout)
|
||||
continue;
|
||||
int b_start = alloc_->offset(b_layout);
|
||||
int b_end = b_start + b_layout->get_size();
|
||||
if(a_start < b_end || b_start < a_end)
|
||||
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
|
||||
if(intersect_with(a_layout, b_layout) ||
|
||||
intersect_with(a_layout, b_tmp) ||
|
||||
intersect_with(a_tmp, b_layout) ||
|
||||
intersect_with(a_tmp, b_tmp))
|
||||
ret.insert(b);
|
||||
}
|
||||
}
|
||||
@@ -81,6 +91,8 @@ void membar::transfer(ir::basic_block *block,
|
||||
std::set<ir::value*> read;
|
||||
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
|
||||
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
|
||||
if(layouts_->has_tmp(i))
|
||||
read.insert(i);
|
||||
// RAW (async)
|
||||
val_set_t tmp;
|
||||
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
|
||||
|
@@ -8,25 +8,39 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
std::vector<basic_block*> cfg::post_order(function* fn) {
|
||||
std::stack<basic_block*> stack;
|
||||
std::set<basic_block*> visited;
|
||||
std::vector<basic_block*> result;
|
||||
// initialize stack
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
if(block->get_predecessors().empty())
|
||||
if(block->get_predecessors().empty()){
|
||||
stack.push(block);
|
||||
visited.insert(block);
|
||||
}
|
||||
// DFS
|
||||
while(!stack.empty()) {
|
||||
basic_block* current = stack.top();
|
||||
bool tail = true;
|
||||
for(basic_block* succ: current->get_successors())
|
||||
if(visited.find(succ) == visited.end()){
|
||||
stack.push(succ);
|
||||
visited.insert(succ);
|
||||
tail = false;
|
||||
break;
|
||||
}
|
||||
if(tail){
|
||||
stack.pop();
|
||||
result.push_back(current);
|
||||
visited.insert(current);
|
||||
for(basic_block* succ: current->get_successors())
|
||||
if(visited.find(succ) == visited.end())
|
||||
stack.push(succ);
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
auto result = post_order(fn);
|
||||
std::reverse(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
|
||||
|
Reference in New Issue
Block a user