[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for Ampere GPUs:
* Make the layout pass recognize the multistage pipelined pattern.
* The pipeline pass can now automate the multistage pipelining transformation.
* Remove extra barriers (from the prefetch pass & WAR) on Ampere.
* Update the code generator (generator.cc) so that Triton generates n-buffered shared memory loads/stores.
This commit is contained in:
committed by
Philippe Tillet
parent
5a51f3e529
commit
d8d6b715c8
@@ -25,13 +25,12 @@ static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::in
|
||||
}
|
||||
|
||||
void prefetch::run(ir::module &mod) {
|
||||
// 1. collect dot that can be prefetched
|
||||
// 1. collect dots that can be prefetched
|
||||
std::vector<ir::dot_inst*> to_prefetch;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
||||
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
||||
// Now only do prefetching when dot is fp16 & volta/turing
|
||||
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID ||
|
||||
tgt_->as_nvidia()->sm() >= 80)
|
||||
// Now only do prefetching when dot is fp16
|
||||
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
|
||||
return;
|
||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||
@@ -56,16 +55,31 @@ void prefetch::run(ir::module &mod) {
|
||||
|
||||
// 1. in the loop header (first iteration)
|
||||
builder.set_insert_point(loop_header->get_inst_list().back());
|
||||
builder.create_barrier();
|
||||
assert(a && b);
|
||||
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
|
||||
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
|
||||
|
||||
// 2. at the end of the loop body (next iteration)
|
||||
builder.set_insert_point(loop_body->get_inst_list().back());
|
||||
builder.create_barrier();
|
||||
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
|
||||
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
|
||||
|
||||
prefetched_vals_.insert(a->get_incoming_value(0));
|
||||
prefetched_vals_.insert(b->get_incoming_value(0));
|
||||
// nested phis
|
||||
ir::value* next_a = a->get_incoming_value(1);
|
||||
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
|
||||
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
|
||||
next_a = next_a_phi->get_incoming_value(1);
|
||||
}
|
||||
prefetched_vals_.insert(next_a);
|
||||
|
||||
ir::value* next_b = b->get_incoming_value(1);
|
||||
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
|
||||
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
|
||||
next_b = next_b_phi->get_incoming_value(1);
|
||||
}
|
||||
prefetched_vals_.insert(next_b);
|
||||
}
|
||||
|
||||
// move loads to the beginning of the loop
|
||||
|
Reference in New Issue
Block a user