[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for Ampere GPUs.
* Make the layout pass recognize the multistage pipelined pattern.
* The pipeline pass can now automate the multistage pipelining transformation.
* Remove extra barriers (from the prefetch pass and from WAR hazards) on Ampere.
* Update the code generator (generator.cc) so that Triton generates n-buffered shared memory loads/stores.
This commit is contained in:
committed by
Philippe Tillet
parent
5a51f3e529
commit
d8d6b715c8
@@ -7,6 +7,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
// #include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -273,6 +274,81 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
|
||||
res.reset(new double_buffer_info_t{value_1, value_0, phi});
|
||||
}
|
||||
|
||||
/// Returns true when `v` is either an ir::copy_to_shared_inst or an
/// ir::masked_load_async_inst, and false for anything else (including null).
static bool is_smem(ir::value* v) {
  return dynamic_cast<ir::copy_to_shared_inst*>(v) != nullptr ||
         dynamic_cast<ir::masked_load_async_inst*>(v) != nullptr;
}
|
||||
|
||||
/// param:
|
||||
/// value_1: next_value
|
||||
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
|
||||
std::vector<ir::value*>& values_0, ir::value*& value_1) {
|
||||
ir::value* next = phi;
|
||||
while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
|
||||
// smem from previous bb & phi/smem from current bb
|
||||
ir::value* c0 = cphi->get_incoming_value(0);
|
||||
ir::value* c1 = cphi->get_incoming_value(1);
|
||||
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
|
||||
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
|
||||
|
||||
if (is_smem(c0)) {
|
||||
assert(cbb0 == bb0);
|
||||
values_0.push_back(c0);
|
||||
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
|
||||
next = phi1;
|
||||
continue;
|
||||
} else {
|
||||
if (is_smem(c1)) {
|
||||
value_1 = c1;
|
||||
assert(cbb1 == bb1);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks whether `v` heads a multistage (N-buffered) pipeline and, if it has
/// more stages than any candidate recorded so far, stores its description in
/// `res`.
///
/// param:
///   v:           candidate value; anything that is not a phi node is ignored
///   res:         output; reset to a new N_buffer_info_t only when this
///                candidate has more stages than `prev_stages`
///   prev_stages: in/out; largest stage count recorded so far, updated together
///                with `res`
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  // only a phi node can head the pattern
  if (!phi)
    return;

  ir::basic_block *bb0 = phi->get_incoming_block(0);
  ir::basic_block *bb1 = phi->get_incoming_block(1);

  std::vector<ir::value*> values_0;
  // is_multistage_pipe_phi only writes this on success; keep it well-defined
  // on every path instead of leaving an indeterminate pointer around.
  ir::value* value_1 = nullptr;

  if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
    return;

  // double-buffer is a special case
  if (values_0.size() == 1)
    return;

  // compute original values_0 input order: number the collected values by
  // the position at which they appear among the instructions of bb0
  std::map<ir::value*, int> order;
  int idx = 0;
  for (ir::instruction* instr : *bb0) {
    if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
      order[static_cast<ir::value*>(instr)] = idx++;
  }
  assert(order.size() == values_0.size() && "order size incorrect");

  // one stage per collected edge-0 value, plus one for value_1
  const int curr_stages = static_cast<int>(values_0.size()) + 1;
  if (curr_stages > prev_stages) {
    res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
    prev_stages = curr_stages;
  }
}
|
||||
|
||||
|
||||
shared_layout::shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
@@ -284,9 +360,15 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
size_ = 0;
|
||||
arg_layout_ = arg;
|
||||
|
||||
// N-stage buffering
|
||||
int prev_stages = 0;
|
||||
for (ir::value *v : values)
|
||||
extract_N_bufferable(v, N_buffer_, prev_stages);
|
||||
|
||||
// double-buffering
|
||||
for(ir::value *v: values)
|
||||
extract_double_bufferable(v, double_buffer_);
|
||||
if (!N_buffer_)
|
||||
for(ir::value *v: values)
|
||||
extract_double_bufferable(v, double_buffer_);
|
||||
|
||||
// order
|
||||
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
|
||||
@@ -311,8 +393,22 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
size_ *= s;
|
||||
if(double_buffer_)
|
||||
size_ *= 2;
|
||||
if (N_buffer_) {
|
||||
size_ *= (N_buffer_->firsts.size() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
int shared_layout::get_num_stages() const {
|
||||
if (double_buffer_)
|
||||
return 2;
|
||||
if (N_buffer_)
|
||||
return N_buffer_->firsts.size() + 1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t shared_layout::get_per_stage_elements() const {
|
||||
return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
|
||||
}
|
||||
|
||||
/* -------------------------------- *
|
||||
* ---- Layouts Inference Pass ---- *
|
||||
@@ -403,7 +499,6 @@ void layouts::run(ir::module &mod) {
|
||||
for(const auto& x: values_)
|
||||
create(x.first, x.second);
|
||||
|
||||
|
||||
// create temporaries
|
||||
size_t id = values_.size();
|
||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||
|
Reference in New Issue
Block a user