[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for the Ampere GPUs. * Make the layout pass recognize the multistage pipelined pattern. * Now the pipeline pass can automate the multistage pipelining transformation. * Remove extra barriers (from the prefetch pass & WAR) on Ampere. * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores.
This commit is contained in:
committed by
Philippe Tillet
parent
5a51f3e529
commit
d8d6b715c8
@@ -23,6 +23,60 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
/// assume incoming block is 1
|
||||
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
|
||||
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
|
||||
throw std::runtime_error("Don't have that phi node\n");
|
||||
return prev_phi_vals.at(phi);
|
||||
}
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Collects phi nodes reachable through the operands of `cond` into `phis`.
///
/// The operands of `cond` are scanned in order. The first operand that is a
/// phi node is inserted into `phis` and the scan of this instruction stops
/// there (later operands of the same instruction are not visited). Operands
/// that are non-phi instructions are searched recursively.
///
/// @param cond  Value to inspect; typically a loop branch condition. If it is
///              not an instruction (so it has no operands to walk), nothing
///              is collected.
/// @param phis  Output set receiving the phi nodes that were found.
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
  auto* instr = dynamic_cast<ir::instruction*>(cond);
  // Guard: dynamic_cast yields nullptr for non-instruction values; calling
  // ops() on it would dereference a null pointer.
  if (instr == nullptr)
    return;

  for (ir::value* op : instr->ops()) {
    if (auto* phi_op = dynamic_cast<ir::phi_node*>(op)) {
      phis.insert(phi_op);
      return;
    }
    if (dynamic_cast<ir::instruction*>(op))
      get_induction_vars(op, phis);
  }
}
|
||||
|
||||
/// Returns phi_val if sees a phi node
|
||||
ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi_val;
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize_val(builder, op, phi_val));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
@@ -41,6 +95,28 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// moving the prev phi vals to the next iteration
|
||||
void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
for (auto& [phi, val] : prev_phi_vals) {
|
||||
// TODO: handling nested phis
|
||||
val = rematerialize_val(builder, phi->get_incoming_value(1), val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Completes the new phi nodes stored as values in `load_ivs`.
///
/// For every (original phi, new phi) pair, the original phi's incoming value
/// at index 1 is rematerialized with phi nodes substituted from `load_ivs`.
/// The result becomes the second incoming value of the new phi, paired with
/// the original phi's incoming block at index 1, and is also recorded in
/// `next_load_ivs` under the original phi.
///
/// @param builder        Builder used to insert the rematerialized instructions.
/// @param load_ivs       Maps each original phi to its new phi, which must
///                       have exactly one operand on entry (asserted).
/// @param next_load_ivs  Output map: original phi -> rematerialized next value.
/// @throws std::runtime_error if a value in `load_ivs` is not a phi node.
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
                      std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
  for (auto& entry : load_ivs) {
    ir::phi_node* orig_phi = entry.first;
    auto* pending_phi = dynamic_cast<ir::phi_node*>(entry.second);
    if (pending_phi == nullptr)
      throw std::runtime_error("must be phi");

    ir::value* next_k = rematerialize_vals(builder, orig_phi->get_incoming_value(1), load_ivs);
    assert(pending_phi->get_num_operands() == 1 && "should be incomplete phi");
    pending_phi->add_incoming(next_k, orig_phi->get_incoming_block(1));
    // cache next_k (to be used by next_mask)
    next_load_ivs[orig_phi] = next_k;
  }
}
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
@@ -60,6 +136,8 @@ void pipeline::run(ir::module &mod) {
|
||||
// do the pipelining
|
||||
std::vector<ir::phi_node*> new_loads;
|
||||
ir::builder &builder = mod.get_builder();
|
||||
const int num_stages = num_stages_;
|
||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||
for(auto info: to_pipeline){
|
||||
ir::load_inst* load = info.first;
|
||||
ir::phi_node* ptr = info.second;
|
||||
@@ -70,40 +148,155 @@ void pipeline::run(ir::module &mod) {
|
||||
assert(block_br);
|
||||
assert(header_br);
|
||||
ir::type* ty = load->get_type();
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
// multi-stage pipe
|
||||
if (has_copy_async_ && num_stages > 2) {
|
||||
ir::value* header_cond = header_br->get_cond();
|
||||
ir::value* block_cond = block_br->get_cond();
|
||||
// 1. collect induction variables
|
||||
std::set<ir::phi_node*> induction_vars;
|
||||
get_induction_vars(block_cond, induction_vars);
|
||||
|
||||
std::vector<ir::value*> first_ptrs(num_stages-1);
|
||||
std::vector<ir::value*> first_loads(num_stages-1);
|
||||
std::vector<ir::value*> first_masks(num_stages-1);
|
||||
std::vector<ir::value*> loop_conds(num_stages-1);
|
||||
|
||||
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
|
||||
// initialize prev_phi_vals
|
||||
// note: we assume that ptr & other values only depend on ptr & iv (phis)
|
||||
// TODO: can we just add all phis here?
|
||||
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
|
||||
for (ir::phi_node* iv : induction_vars)
|
||||
prev_phi_vals[iv] = iv->get_value_for_block(header);
|
||||
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
|
||||
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
first_ptrs[0] = ptr->get_value_for_block(header);
|
||||
loop_conds[0] = header_cond;
|
||||
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
|
||||
ir::value* false_value = nullptr;
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
|
||||
update_prev_phi_vals(builder, prev_phi_vals);
|
||||
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
|
||||
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
|
||||
}
|
||||
|
||||
// create new phis for induction variables
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
std::map<ir::phi_node*, ir::value*> load_ivs;
|
||||
std::map<ir::phi_node*, ir::value*> next_load_ivs;
|
||||
for (ir::phi_node* iv : induction_vars) {
|
||||
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
|
||||
pn->add_incoming(prev_phi_vals[iv], header);
|
||||
load_ivs[iv] = pn;
|
||||
}
|
||||
// add incoming for phis & update next_load_ivs
|
||||
finalize_iv_vals(builder, load_ivs, next_load_ivs);
|
||||
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(
|
||||
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
|
||||
// TODO: the false value may depend on some other phi nodes
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
|
||||
|
||||
// phi node
|
||||
ptr->set_incoming_value(0, first_ptrs.back());
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
// nested phis for load
|
||||
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
|
||||
for (auto& pn : new_load_phis)
|
||||
pn = builder.create_phi(ty, 2);
|
||||
for (int i=0; i<num_stages-2; ++i) {
|
||||
new_load_phis[i]->add_incoming(first_loads[i], header);
|
||||
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
|
||||
}
|
||||
new_load_phis.back()->add_incoming(first_loads.back(), header);
|
||||
new_load_phis.back()->add_incoming(next_load, block);
|
||||
load->replace_all_uses_with(new_load_phis.front());
|
||||
new_loads.push_back(new_load_phis.back());
|
||||
|
||||
// record first_loads to reorder them
|
||||
preheader_loads.push_back({new_load_phis.front(), first_loads});
|
||||
} else {
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||
new_load->add_incoming(first_load, header);
|
||||
new_load->add_incoming(next_load, block);
|
||||
load->replace_all_uses_with(new_load);
|
||||
new_loads.push_back(new_load);
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||
new_load->add_incoming(first_load, header);
|
||||
new_load->add_incoming(next_load, block);
|
||||
load->replace_all_uses_with(new_load);
|
||||
new_loads.push_back(new_load);
|
||||
}
|
||||
|
||||
// try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to
|
||||
// a0, b0, a1, b1, ...
|
||||
if (!preheader_loads.empty()) {
|
||||
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
for (int i=1; i<num_stages-1; ++i) {
|
||||
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
|
||||
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
|
||||
ir::instruction* moved_load = original_load->clone();
|
||||
builder.insert(moved_load);
|
||||
original_load->replace_all_uses_with(moved_load);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try to move dot_inst after loads
|
||||
// for better overlap of io and compute
|
||||
|
Reference in New Issue
Block a user