[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copies, scheduling, and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks.
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
lib/codegen/transform/membar.cc

@@ -15,114 +15,105 @@ namespace triton {

namespace codegen{

namespace transform{

-bool membar::intersect(const interval_vec_t &X, interval_t x) {
-  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
-    bool left_intersect = y.first <= x.first && x.first < y.second;
-    bool right_intersect = y.first <= x.second && x.second < y.second;
-    return left_intersect || right_intersect;
-  });
-}
-
-bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
-  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
-    return intersect(X, y);
-  });
-}
-
-void membar::add_reference(ir::value *v, interval_vec_t &res){
-  auto *i = dynamic_cast<ir::instruction*>(v);
-  if(!i)
-    return;
-  if(!i->get_type()->is_tile_ty())
-    return;
-  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-  if(!layout)
-    return;
-  if(alloc_->has_offset(layout)){
-    unsigned offset = alloc_->offset(layout);
-    res.push_back(interval_t(offset, offset + layout->get_size()));
-  }
-}
-
-void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
-  for(ir::value *op: i->ops())
-    add_reference(op, res);
-}
-
-void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
-  if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
-    add_reference(i, res);
-}
-
-void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
-  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
-    std::set<ir::value*> incoming;
-    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
-      ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
-      assert(inc_val);
-      if(incoming.insert(inc_val).second){
-        ir::basic_block *block = inc_val->get_parent();
-        builder.set_insert_point(block->get_inst_list().back());
-        if(type.first)
-          builder.create_async_wait();
-        if(type.second)
-          builder.create_barrier();
-      }
-    }
-  }
-  else {
-    builder.set_insert_point(instr);
-    builder.create_barrier();
-  }
-}
-
-membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
-  membar::interval_vec_t result;
-  for(auto x: intervals)
-    for(interval_t i: x)
-      result.push_back(i);
-  return result;
-}
-
-std::pair<membar::interval_vec_t,
-          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
-                                                   const interval_vec_t &written_to,
-                                                   const interval_vec_t &read_from,
-                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
-                                                   std::set<ir::value*>& safe_war,
-                                                   std::vector<ir::instruction*>& to_sync) {
-  ir::basic_block::inst_list_t instructions = block->get_inst_list();
-  interval_vec_t new_written_to = written_to;
-  interval_vec_t new_read_from = read_from;
-
-  for(ir::instruction *i: instructions){
-    interval_vec_t read, written;
-    get_read_intervals(i, read);
-    get_written_intervals(i, written);
-    if(written.size())
-      to_sync.push_back(i);
-    bool read_after_write = intersect(new_written_to, read);
-    bool write_after_read = intersect(new_read_from, written);
-    // double buffering
-    if(safe_war.find(i) != safe_war.end()){
-      write_after_read = false;
-      read_after_write = false;
-    }
-    // record hazards
-    if(read_after_write || write_after_read) {
-      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
-      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
-      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
-      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
-      insert_loc.insert({i, {copy_async_wait, barrier}});
-      new_written_to.clear();
-      new_read_from.clear();
-      to_sync.clear();
-    }
-    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
-    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
-  }
-  return std::make_pair(new_written_to, new_read_from);
-}
+int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
+    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+    analysis::double_buffer_info_t* info = layout->get_double_buffer();
+    if(info)
+      return group_of(info->first, async_write);
+    std::vector<int> groups(phi->get_num_operands());
+    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+    return *std::max_element(groups.begin(), groups.end());
+  }
+  else{
+    auto it = std::find(async_write.begin(), async_write.end(), v);
+    return std::distance(async_write.begin(), it);
+  }
+}
+
+membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
+  val_set_t ret;
+  for(ir::value* a: as){
+    if(!a->get_type()->is_tile_ty())
+      continue;
+    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
+    if(!a_layout)
+      continue;
+    int a_start = alloc_->offset(a_layout);
+    int a_end = a_start + a_layout->get_size();
+    for(ir::value* b: bs){
+      if(!b->get_type()->is_tile_ty())
+        continue;
+      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
+      if(!b_layout)
+        continue;
+      int b_start = alloc_->offset(b_layout);
+      int b_end = b_start + b_layout->get_size();
+      if(a_start < b_end || b_start < a_end)
+        ret.insert(b);
+    }
+  }
+  return ret;
+}
+
+void membar::transfer(ir::basic_block *block,
+                      val_vec_t& async_write,
+                      val_set_t& sync_write,
+                      val_set_t& sync_read,
+                      std::set<ir::value*>& safe_war,
+                      bool& inserted, ir::builder& builder) {
+  ir::basic_block::inst_list_t instructions = block->get_inst_list();
+  for(ir::instruction *i: instructions){
+    if(dynamic_cast<ir::phi_node*>(i))
+      continue;
+    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
+       dynamic_cast<ir::masked_load_async_inst*>(i)){
+      async_write.push_back(i);
+    }
+    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+      sync_write.insert(i);
+    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
+    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
+    // Get shared memory reads
+    std::set<ir::value*> read;
+    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
+                 [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
+    // RAW (async)
+    val_set_t tmp;
+    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
+    if(intersect_with(read, tmp).size()){
+      std::vector<int> groups(read.size());
+      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+      int N = *std::max_element(groups.begin(), groups.end());
+      if(N < async_write.size()){
+        builder.set_insert_point(i);
+        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
+        barrier = (ir::barrier_inst*)builder.create_barrier();
+        inserted = true;
+      }
+    }
+    // RAW, WAR
+    if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
+      builder.set_insert_point(i);
+      barrier = (ir::barrier_inst*)builder.create_barrier();
+      inserted = true;
+    }
+    // update state of asynchronous copies
+    if(async_wait){
+      int N = async_write.size() - async_wait->get_N();
+      async_write.erase(async_write.begin(), async_write.begin() + N);
+    }
+    // all the copy_to_shared and read from shared are synchronized after barrier
+    if(barrier){
+      sync_write.clear();
+      sync_read.clear();
+    }
+    sync_read.insert(read.begin(), read.end());
+  }
+}

void membar::run(ir::module &mod) {

@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {

  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
-    std::map<ir::basic_block*, interval_vec_t> written_to;
-    std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::vector<ir::instruction*> to_sync;
-    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
-    size_t n_inserted_im1 = 0;
-    bool done = false;
+    std::map<ir::basic_block*, val_vec_t> async_writes;
+    std::map<ir::basic_block*, val_set_t> sync_writes;
+    std::map<ir::basic_block*, val_set_t> sync_reads;
+    std::list<ir::value *> pipelined;
+    bool inserted;
    do{
+      inserted = false;
      // find barrier location
      for(ir::basic_block *block: rpo){
-        // written to
-        std::vector<interval_vec_t> pred_written_to;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_written_to.push_back(written_to[pred]);
-        // read from
-        std::vector<interval_vec_t> pred_read_from;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_read_from.push_back(read_from[pred]);
-        // apply transfer function
-        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
-        written_to[block] = result.first;
-        read_from[block] = result.second;
+        // join inputs
+        val_vec_t async_write;
+        val_set_t sync_write;
+        val_set_t sync_read;
+        val_set_t tmp;
+        for(ir::basic_block* pred: block->get_predecessors()){
+          for(ir::value* v: async_writes[pred])
+            if(tmp.insert(v).second)
+              async_write.push_back(v);
+          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
+          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
+        }
+        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
+        async_writes[block] = async_write;
+        sync_writes[block] = sync_write;
+        sync_reads[block] = sync_read;
      }
-      size_t n_inserted_i = insert_locs.size();
-      done = (n_inserted_im1 == n_inserted_i);
-      n_inserted_im1 = n_inserted_i;
-    }while(!done);
-    for(auto x: insert_locs){
-      insert_barrier(x.first, x.second, builder);
-    }
+    }while(inserted);
  }
}
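The pivotal line in the new membar::transfer is builder.create_async_wait(async_write.size() - 1 - N): async copies retire in issue order, so if the youngest in-flight copy that a read depends on sits at index N of async_write, everything up to and including N must complete, and at most size() - 1 - N copies may remain in flight. A self-contained sketch of that counting rule, with invented tile names (a model of the logic, not Triton's API):

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Toy model (invented names) of how membar::transfer picks the argument
// of create_async_wait: find the youngest in-flight copy the reads
// depend on; everything up to and including it must retire, so at most
// size() - 1 - N copies may still be pending.
int async_wait_arg(const std::vector<std::string>& async_write,
                   const std::vector<std::string>& read) {
  int N = -1;
  for (const std::string& r : read) {
    auto it = std::find(async_write.begin(), async_write.end(), r);
    if (it != async_write.end())
      N = std::max(N, int(it - async_write.begin()));
  }
  if (N < 0)
    return -1;  // no RAW hazard on an in-flight copy: no wait needed
  return int(async_write.size()) - 1 - N;
}

int main() {
  // three async tile copies in flight, oldest first
  std::vector<std::string> in_flight = {"A0", "B0", "A1"};
  // a dot reads tile A0: A0 must retire, two younger copies may remain
  printf("%d\n", async_wait_arg(in_flight, {"A0"}));  // prints 2
  return 0;
}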
lib/codegen/transform/peephole.cc

@@ -1,7 +1,9 @@
#include <algorithm>
+#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
+#include "triton/codegen/analysis/layout.h"

namespace triton {
namespace codegen{

@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
  ir::value *ptr = ld->get_pointer_operand();
  ir::value *msk = ld->get_mask_operand();
  ir::value *val = ld->get_false_value_operand();
-  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
-  copy_to_shared->replace_all_uses_with(new_load);
-  return true;
-  // analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
-  // std::cout << layout->nts(layout->get_order(0)) << std::endl;
-  // return true;
+  analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  int nts = layout->nts(layout->get_order()[0]);
+  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+  if(nts*dtsize >= 4){
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    copy_to_shared->replace_all_uses_with(new_load);
+    return true;
+  }
+  return false;
}

@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
  bool was_modified = false;
  was_modified = was_modified || rewrite_mult(i, builder);
  // was_modified = was_modified || rewrite_cts_cfs(i, builder);
-  was_modified = was_modified || rewrite_trans_phi(i, builder);
+  // was_modified = was_modified || rewrite_trans_phi(i, builder);
  was_modified = was_modified || rewrite_unit_red(i, builder);
  was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
-  // if(tgt_->as_nvidia()->sm() >= 80)
-  //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
+  if(tgt_->as_nvidia()->sm() >= 80)
+    was_modified = was_modified || rewrite_load_to_shared(i, builder);
  if(was_modified)
    seen.insert(i);
}
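The new guard in rewrite_load_to_shared is worth spelling out: nts * dtsize is the number of contiguous bytes each thread accesses along the fastest-moving axis, and cp.async on sm_80 only supports 4-, 8-, and 16-byte transfers per thread, so narrower accesses keep the ordinary load + copy_to_shared path. A minimal sketch of the check (parameter meanings inferred from the surrounding code):

#include <cstdio>

// Sketch of the test above: nts = contiguous elements per thread along
// the fastest axis, dtsize = element size in bytes. cp.async moves 4,
// 8, or 16 bytes per thread, so anything below 4 bytes cannot use it.
bool can_rewrite_to_copy_async(int nts, int dtsize) {
  return nts * dtsize >= 4;
}

int main() {
  printf("%d\n", can_rewrite_to_copy_async(4, 2));  // 4 x fp16 = 8B  -> 1
  printf("%d\n", can_rewrite_to_copy_async(1, 2));  // 1 x fp16 = 2B  -> 0
  return 0;
}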
lib/codegen/transform/pipeline.cc (new file, 116 lines)

@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"

namespace triton {
namespace codegen{
namespace transform{


void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
  if(!i || i->get_parent() != block)
    return;
  if(i->get_id()==ir::INST_PHI)
    return;
  ret.push_back(i);
  for(ir::user* u: i->get_users())
    recursive_deps(u, block, ret);
}

void pipeline::run(ir::module &mod) {
  // *Very* conservative heuristics for pre-fetching.
  // A load instruction can be pipelined if:
  //   - the pointer is a phi node that references a value
  //     in its basic block (i.e., pointer induction variable)
  //   - the load has only a single use in a dot instruction
  // As more use cases become apparent, this pass will be improved
  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
  ir::for_each_instruction(mod, [&](ir::instruction *i){
    if(auto* load = dynamic_cast<ir::load_inst*>(i)){
      ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
      auto users = load->get_users();
      if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
        to_pipeline.push_back({load, ptr});
    }});
  // do the pipelining
  std::vector<ir::phi_node*> new_loads;
  ir::builder &builder = mod.get_builder();
  for(auto info: to_pipeline){
    ir::load_inst* load = info.first;
    ir::phi_node* ptr = info.second;
    ir::basic_block* block = load->get_parent();
    ir::basic_block* header = block->get_predecessors()[0];
    auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
    auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
    assert(block_br);
    assert(header_br);
    ir::type* ty = load->get_type();
    // pre-fetch first iteration
    builder.set_insert_point(header->get_inst_list().back());
    ir::value* first_ptr = ptr->get_value_for_block(header);
    ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
    ir::value* false_value;
    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
      first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
      false_value = masked_load->get_false_value_operand();
    }
    else
      false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
    // pre-fetch next iteration
    builder.set_insert_point(block->get_inst_list().back());
    ir::value* next_ptr = ptr->get_value_for_block(block);
    ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
      next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
    // phi node
    builder.set_insert_point(block->get_first_non_phi());
    ir::phi_node* new_load = builder.create_phi(ty, 2);
    new_load->add_incoming(first_load, header);
    new_load->add_incoming(next_load, block);
    load->replace_all_uses_with(new_load);
    new_loads.push_back(new_load);
  }


  // try to move dot_inst after loads
  // for better overlap of io and compute
  struct move_config_t{
    std::vector<ir::instruction*> insts;
    ir::load_inst* dst;
  };
  std::map<ir::basic_block*, move_config_t> to_move;

  if(has_copy_async_){
    for(ir::function* fn: mod.get_function_list())
    for(ir::basic_block* bb: fn->blocks())
    for(ir::instruction* inst: bb->get_inst_list()){
      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
        recursive_deps(i, bb, to_move[bb].insts);
      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
        to_move[bb].dst = i;
    }

    for(auto& x: to_move){
      builder.set_insert_point_after(x.second.dst);
      for(ir::instruction* i: x.second.insts){
        x.first->erase(i);
        builder.insert(i);
      }
    }
  }
}

}
}
}
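What pipeline::run does to each qualifying load, in scalar form: the loop header issues the load for iteration 0, the loop body issues the load for iteration i+1, and a phi node selects between them, so each load's latency overlaps with the current iteration's compute. A toy C++ analogue of the rewrite (an invented example, not the pass itself); the "< n" guards play the role of the predication masks built from header_br->get_cond() and block_br->get_cond():

#include <cstdio>

// Scalar analogue of the load-pipelining rewrite. Before: load(i) sits
// at the top of the iteration that uses it. After: the header issues
// the first load, each iteration issues the *next* one, and a rotation
// (the phi node) carries the pre-fetched value into the next iteration.
int load(int i) { return i * 10; }  // stand-in for a masked tile load

int main() {
  const int n = 4;
  int acc = 0;
  int cur = (0 < n) ? load(0) : 0;            // pre-fetch first iteration (header)
  for (int i = 0; i < n; i++) {
    int nxt = (i + 1 < n) ? load(i + 1) : 0;  // pre-fetch next iteration
    acc += cur;                               // compute uses the phi value
    cur = nxt;                                // rotate: phi(first, next)
  }
  printf("%d\n", acc);  // 0 + 10 + 20 + 30 = 60
  return 0;
}

The second half of the pass (the move_config_t loop) sinks each dot and its same-block transitive users below the last load, so the async copies issued by the loads are in flight while the tensor-core work executes. A toy model of the slice collection done by recursive_deps (hypothetical graph and names; the real pass also skips phi nodes and values from other blocks):

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// value -> its users; starting from the dot, collect it and every
// transitive user so the whole dependent slice can be moved as a unit.
using Graph = std::map<std::string, std::vector<std::string>>;

void collect(const Graph& g, const std::string& v, std::vector<std::string>& out) {
  out.push_back(v);
  auto it = g.find(v);
  if (it == g.end()) return;
  for (const std::string& u : it->second)
    collect(g, u, out);
}

int main() {
  Graph users = {{"dot", {"add"}}, {"add", {"store"}}};
  std::vector<std::string> slice;
  collect(users, "dot", slice);  // dot, add, store move together
  for (const auto& s : slice) printf("%s ", s.c_str());
  printf("\n");
  return 0;
}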
lib/codegen/transform/reassociate.cc

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
inline bool is_cst(ir::value *x) {
  if(dynamic_cast<ir::constant*>(x))
    return true;
+  if(dynamic_cast<ir::make_range*>(x))
+    return true;
  if(auto *v = dynamic_cast<ir::retile_inst*>(x))
    return is_cst(v->get_operand(0));
  return false;