[CODEGEN] Performance improvement on A100 (#125)

Improved code generation for Ampere GPUs (a conceptual sketch of the resulting n-buffered pipeline follows the list below).

    * Make the layout pass recognize the multi-stage pipelined pattern.
    * The pipeline pass can now automate the multi-stage pipelining transformation.
    * Remove extra barriers (from the prefetch pass & WAR) on Ampere.
    * Update the code generator (generator.cc) to make Triton generate n-buffered shared-memory loads/stores.
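
Below is a minimal stand-alone sketch (not part of this commit) of the n-buffered pipelining pattern the pass generates, assuming num_stages = 3: issue num_stages - 1 copies before the loop, then each iteration consumes the oldest tile and issues the copy for a later iteration. The names kNumStages, kTile and async_copy are illustrative only; the asynchronous copy is modeled as a plain memcpy standing in for cp.async on sm_80.

#include <algorithm>
#include <array>
#include <cstdio>
#include <numeric>
#include <vector>

constexpr int kNumStages = 3;  // e.g. the pipeline pass's num_stages_
constexpr int kTile = 4;       // elements per "shared-memory" tile

int main() {
  std::vector<int> global(32);
  std::iota(global.begin(), global.end(), 0);
  const int num_tiles = static_cast<int>(global.size()) / kTile;

  // circular buffer of num_stages - 1 tiles, mimicking n-buffered shared memory
  std::array<std::array<int, kTile>, kNumStages - 1> smem{};
  auto async_copy = [&](int tile, int buf) {  // stand-in for an asynchronous copy
    if (tile < num_tiles)
      std::copy_n(global.begin() + tile * kTile, kTile, smem[buf].begin());
  };

  // prologue: prefetch the first num_stages - 1 tiles (the "first_loads" in the pass)
  for (int s = 0; s < kNumStages - 1; ++s)
    async_copy(s, s);

  long long acc = 0;
  for (int i = 0; i < num_tiles; ++i) {
    int buf = i % (kNumStages - 1);
    // an async_wait + barrier belongs here: the oldest copy must have landed
    for (int v : smem[buf]) acc += v;      // the "dot" consumes the tile
    async_copy(i + kNumStages - 1, buf);   // prefetch for iteration i + num_stages - 1
  }
  std::printf("acc = %lld\n", acc);        // 0 + 1 + ... + 31 = 496
}

With num_stages = 2 this degenerates to the existing double-buffering path; larger num_stages simply deepens the circular buffer, which is what the n-buffered shared-memory layout supports.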
Author: daadaada
Date: 2021-06-21 14:25:13 +08:00
Committed by: Philippe Tillet
Parent: 5a51f3e529
Commit: d8d6b715c8
21 changed files with 855 additions and 174 deletions

lib/codegen/transform/membar.cc

@@ -4,6 +4,7 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -20,9 +21,14 @@ namespace transform{
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
return group_of(info->first, async_write);
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
if (v == info->phi)
return group_of(info->firsts[0], async_write);
else // prefetched value
return group_of(info->firsts[1], async_write);
}
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
@@ -69,12 +75,31 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
return ret;
}
bool membar::check_safe_war(ir::instruction* i) {
bool is_i_shared_block = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared();
bool is_i_double_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_double_buffer();
bool is_i_n_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_N_buffer();
if (is_i_double_buffered || is_i_n_buffered) {
// with async copy & prefetch_s disabled, WARs are not safe
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
return false;
else
return true;
}
return false;
}
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
bool& inserted, ir::builder& builder) {
std::vector<ir::async_wait_inst*> async_waits;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
if(dynamic_cast<ir::phi_node*>(i))
@@ -105,18 +130,14 @@ void membar::transfer(ir::basic_block *block,
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
async_waits.push_back(async_wait);
}
}
// RAW, WAR
bool is_safe_war = check_safe_war(i);
// WAR barrier is not required when data is double-buffered
if(!intersect_with(read, sync_write).empty() ||
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
@@ -132,7 +153,41 @@ void membar::transfer(ir::basic_block *block,
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
// coalesce barriers
// fixme: to support more general cases
if (async_waits.size() == 2) {
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
for (int idx=0; idx<async_waits.size()-1; ++idx) {
ir::async_wait_inst *first_async_wait = async_waits[idx];
std::vector<ir::instruction*> to_erase;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
ir::instruction *i = *iter;
if (static_cast<ir::instruction*>(first_async_wait) == i) {
// peek at the next 5 instructions
auto peak_iter = std::next(iter);
if (std::distance(peak_iter, instructions.end()) >= 5) {
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
int first_n = first_async_wait->get_N();
int second_n = second_async_wait->get_N();
to_erase.push_back(second_async_wait);
to_erase.push_back(second_bar);
first_async_wait->set_N(second_n);
}
} else
break;
for (ir::instruction *i : to_erase)
block->erase(i);
}
}
}
}
}
@@ -144,7 +199,7 @@ void membar::run(ir::module &mod) {
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi){
@@ -153,7 +208,6 @@ void membar::run(ir::module &mod) {
}
for(ir::function *fn: mod.get_function_list()){
// TODO: (dyan) we need DominatorTree here.
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
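
The coalescing loop in membar::transfer above rewrites "async_wait N; barrier; prefetch; async_wait N-1; barrier; prefetch" into "async_wait N-1; barrier; prefetch; prefetch". The following is a stand-alone sketch of that peephole under an illustrative assumption (instructions modeled as strings on a toy list, not Triton's IR classes):

#include <iostream>
#include <iterator>
#include <list>
#include <string>

// prefix test, e.g. starts_with(inst, "async_wait")
static bool starts_with(const std::string& inst, const std::string& op) {
  return inst.rfind(op, 0) == 0;
}

int main() {
  std::list<std::string> insts = {"async_wait 2", "barrier", "prefetch A",
                                  "async_wait 1", "barrier", "prefetch B", "dot"};
  for (auto it = insts.begin(); it != insts.end(); ++it) {
    if (!starts_with(*it, "async_wait")) continue;
    auto p = std::next(it);                      // peek at the next 5 instructions
    if (std::distance(p, insts.end()) < 5) break;
    auto bar1 = p++;
    auto pf1 = p++;
    auto aw2 = p++;
    auto bar2 = p++;
    auto pf2 = p;
    if (starts_with(*bar1, "barrier") && starts_with(*pf1, "prefetch") &&
        starts_with(*aw2, "async_wait") && starts_with(*bar2, "barrier") &&
        starts_with(*pf2, "prefetch")) {
      *it = *aw2;        // the surviving wait takes the smaller N
      insts.erase(aw2);  // drop the second wait and its barrier
      insts.erase(bar2);
    }
    break;  // like the pass, coalesce at most one pair per block
  }
  for (const std::string& s : insts) std::cout << s << '\n';
  // prints: async_wait 1, barrier, prefetch A, prefetch B, dot
}

Keeping only the wait with the smaller N is a strictly stronger condition than the first wait it replaces, so both prefetches still see their data in shared memory while one async_wait and one barrier are saved.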

lib/codegen/transform/pipeline.cc

@@ -23,6 +23,60 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
auto instr = dynamic_cast<ir::instruction*>(cond);
for (auto op : instr->ops()) {
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
phis.insert(phi_op);
return;
}
if (dynamic_cast<ir::instruction*>(op))
get_induction_vars(op, phis);
}
}
/// Returns phi_val if it sees a phi node
ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi_val;
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_val(builder, op, phi_val));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
@@ -41,6 +95,28 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
return ret;
}
/// moving the prev phi vals to the next iteration
void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
for (auto& [phi, val] : prev_phi_vals) {
// TODO: handle nested phis
val = rematerialize_val(builder, phi->get_incoming_value(1), val);
}
}
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
next_load_ivs[phi] = next_k;
} else
throw std::runtime_error("must be phi");
}
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
@@ -60,6 +136,8 @@ void pipeline::run(ir::module &mod) {
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
@@ -70,40 +148,155 @@ void pipeline::run(ir::module &mod) {
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// multi-stage pipe
if (has_copy_async_ && num_stages > 2) {
ir::value* header_cond = header_br->get_cond();
ir::value* block_cond = block_br->get_cond();
// 1. collect induction variables
std::set<ir::phi_node*> induction_vars;
get_induction_vars(block_cond, induction_vars);
std::vector<ir::value*> first_ptrs(num_stages-1);
std::vector<ir::value*> first_loads(num_stages-1);
std::vector<ir::value*> first_masks(num_stages-1);
std::vector<ir::value*> loop_conds(num_stages-1);
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// note: we assume that ptr & other values only depend on ptr & iv (phis)
// TODO: can we just add all phis here?
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
for (ir::phi_node* iv : induction_vars)
prev_phi_vals[iv] = iv->get_value_for_block(header);
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
loop_conds[0] = header_cond;
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
update_prev_phi_vals(builder, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
}
// create new phis for induction variables
builder.set_insert_point(block->get_first_non_phi());
std::map<ir::phi_node*, ir::value*> load_ivs;
std::map<ir::phi_node*, ir::value*> next_load_ivs;
for (ir::phi_node* iv : induction_vars) {
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
pn->add_incoming(prev_phi_vals[iv], header);
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
// TODO: the false value may depend on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
ptr->set_incoming_value(0, first_ptrs.back());
builder.set_insert_point(block->get_first_non_phi());
// nested phis for load
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
for (auto& pn : new_load_phis)
pn = builder.create_phi(ty, 2);
for (int i=0; i<num_stages-2; ++i) {
new_load_phis[i]->add_incoming(first_loads[i], header);
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
}
new_load_phis.back()->add_incoming(first_loads.back(), header);
new_load_phis.back()->add_incoming(next_load, block);
load->replace_all_uses_with(new_load_phis.front());
new_loads.push_back(new_load_phis.back());
// record first_loads to reorder them
preheader_loads.push_back({new_load_phis.front(), first_loads});
} else {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
}
// try to reorder prefetched values from a0, a1, a2, ..., b0, b1, b2, ... to
// a0, b0, a1, b1, ...
if (!preheader_loads.empty()) {
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
builder.set_insert_point(header->get_inst_list().back());
for (int i=1; i<num_stages-1; ++i) {
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
ir::instruction* moved_load = original_load->clone();
builder.insert(moved_load);
original_load->replace_all_uses_with(moved_load);
}
}
}
// try to move dot_inst after loads
// for better overlap of io and compute

lib/codegen/transform/prefetch.cc

@@ -25,13 +25,12 @@ static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::in
}
void prefetch::run(ir::module &mod) {
// 1. collect dots that can be prefetched
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is fp16
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
@@ -56,16 +55,31 @@ void prefetch::run(ir::module &mod) {
// 1. in the loop header (first iteration)
builder.set_insert_point(loop_header->get_inst_list().back());
builder.create_barrier();
assert(a && b);
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
// 2. at the end of the loop body (next iteration)
builder.set_insert_point(loop_body->get_inst_list().back());
builder.create_barrier();
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
prefetched_vals_.insert(a->get_incoming_value(0));
prefetched_vals_.insert(b->get_incoming_value(0));
// nested phis
ir::value* next_a = a->get_incoming_value(1);
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
next_a = next_a_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_a);
ir::value* next_b = b->get_incoming_value(1);
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
next_b = next_b_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_b);
}
// move loads to the beginning of the loop