[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for the Ampere GPUs. * Make the layout pass recognize the multistage pipelined pattern. * Now the pipeline pass can automate the multistage pipelining transformation. * Remove extra barriers (from the prefetch pass & WAR) on Ampere. * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores.
This commit is contained in:
committed by
Philippe Tillet
parent
5a51f3e529
commit
d8d6b715c8
@@ -4,6 +4,7 @@
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -20,9 +21,14 @@ namespace transform{
|
||||
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
|
||||
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
|
||||
analysis::double_buffer_info_t* info = layout->get_double_buffer();
|
||||
if(info)
|
||||
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
|
||||
return group_of(info->first, async_write);
|
||||
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
|
||||
if (v == info->phi)
|
||||
return group_of(info->firsts[0], async_write);
|
||||
else // prefetched value
|
||||
return group_of(info->firsts[1], async_write);
|
||||
}
|
||||
std::vector<int> groups(phi->get_num_operands());
|
||||
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
|
||||
return *std::max_element(groups.begin(), groups.end());
|
||||
@@ -69,12 +75,31 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool membar::check_safe_war(ir::instruction* i) {
|
||||
bool is_i_shared_block = i->get_type()->is_block_ty() &&
|
||||
layouts_->get(i)->to_shared();
|
||||
bool is_i_double_buffered = is_i_shared_block &&
|
||||
layouts_->get(i)->to_shared()->get_double_buffer();
|
||||
bool is_i_n_buffered = is_i_shared_block &&
|
||||
layouts_->get(i)->to_shared()->get_N_buffer();
|
||||
|
||||
if (is_i_double_buffered || is_i_n_buffered) {
|
||||
// with async copy & prefetch_s disabled, WARs are not safe
|
||||
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void membar::transfer(ir::basic_block *block,
|
||||
val_vec_t& async_write,
|
||||
val_set_t& sync_write,
|
||||
val_set_t& sync_read,
|
||||
std::set<ir::value*>& safe_war,
|
||||
bool& inserted, ir::builder& builder) {
|
||||
std::vector<ir::async_wait_inst*> async_waits;
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
for(ir::instruction *i: instructions){
|
||||
if(dynamic_cast<ir::phi_node*>(i))
|
||||
@@ -105,18 +130,14 @@ void membar::transfer(ir::basic_block *block,
|
||||
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
|
||||
barrier = (ir::barrier_inst*)builder.create_barrier();
|
||||
inserted = true;
|
||||
async_waits.push_back(async_wait);
|
||||
}
|
||||
}
|
||||
// RAW, WAR
|
||||
bool is_i_double_buffered = i->get_type()->is_block_ty() &&
|
||||
layouts_->get(i)->to_shared() &&
|
||||
layouts_->get(i)->to_shared()->get_double_buffer();
|
||||
bool is_safe_war = check_safe_war(i);
|
||||
// WAR barrier is not required when data is double-buffered
|
||||
// TODO: how about other patterns, like WWAR?
|
||||
if(!intersect_with(read, sync_write).empty() ||
|
||||
(!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
|
||||
// force WAR barrier on A100
|
||||
(!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
|
||||
if(!intersect_with(read, sync_write).empty() ||
|
||||
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
|
||||
builder.set_insert_point(i);
|
||||
barrier = (ir::barrier_inst*)builder.create_barrier();
|
||||
inserted = true;
|
||||
@@ -132,7 +153,41 @@ void membar::transfer(ir::basic_block *block,
|
||||
sync_read.clear();
|
||||
}
|
||||
sync_read.insert(read.begin(), read.end());
|
||||
}
|
||||
|
||||
// coalesce barriers
|
||||
// fixme: to support more general cases
|
||||
if (async_waits.size() == 2) {
|
||||
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
|
||||
for (int idx=0; idx<async_waits.size()-1; ++idx) {
|
||||
ir::async_wait_inst *first_async_wait = async_waits[idx];
|
||||
std::vector<ir::instruction*> to_erase;
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
|
||||
ir::instruction *i = *iter;
|
||||
if (static_cast<ir::instruction*>(first_async_wait) == i) {
|
||||
// peak next 5 instructions
|
||||
auto peak_iter = std::next(iter);
|
||||
if (std::distance(peak_iter, instructions.end()) >= 5) {
|
||||
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
|
||||
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
|
||||
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
|
||||
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
|
||||
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
|
||||
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
|
||||
int first_n = first_async_wait->get_N();
|
||||
int second_n = second_async_wait->get_N();
|
||||
to_erase.push_back(second_async_wait);
|
||||
to_erase.push_back(second_bar);
|
||||
first_async_wait->set_N(second_n);
|
||||
}
|
||||
} else
|
||||
break;
|
||||
for (ir::instruction *i : to_erase)
|
||||
block->erase(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +199,7 @@ void membar::run(ir::module &mod) {
|
||||
std::set<ir::value*> safe_war;
|
||||
for(const auto& x: layouts_->get_all()){
|
||||
analysis::shared_layout* layout = x.second->to_shared();
|
||||
if(!layout || !layout->get_double_buffer())
|
||||
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
|
||||
continue;
|
||||
for(ir::value *v: layout->get_values())
|
||||
if(v != layout->get_double_buffer()->phi){
|
||||
@@ -153,7 +208,6 @@ void membar::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// TODO: (dyan) we need DominatorTree here.
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
std::map<ir::basic_block*, val_vec_t> async_writes;
|
||||
std::map<ir::basic_block*, val_set_t> sync_writes;
|
||||
|
Reference in New Issue
Block a user