[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copy, scheduling and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks
This commit is contained in:
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions

View File

@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
// std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
}
return add_to_cache(x, result, max_contiguous_);
}
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}

View File

@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// std::cout << order_[0] << " " << order_[1] << std::endl;
}
if(is_recoalesce){
if(ptr.size() > 0){
// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// if(order_[0] == 0)
// exit(1);
}
}
// std::cout << "---" << std::endl;
}
int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand();
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
ptr = io->get_pointer_operand();
}
unsigned i = order_[0];
int contiguous = 1;
if(ptr){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));

View File

@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
}
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
unsigned vector = 1;
ir::value *ptrs = x->get_pointer_operand();
ir::value *msks = x->get_mask_operand();
unsigned in_vec = 1;
ir::value *arg = x->get_pointer_operand();
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
vector = in_layout->nts(in_order[0]);
in_vec = in_layout->nts(in_order[0]);
int out_vec = swizzle_->get_vec(out_layout);
int min_vec = std::min<int>(out_vec, in_vec);
int s = std::max<int>(out_vec / in_vec, 1);
//
int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
Value *max_phase = i32(8 / num_per_phase);
int per_phase = swizzle_->get_per_phase(out_layout);
int max_phase = swizzle_->get_max_phase(out_layout);
//
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_tile_shapes();
//
int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
std::vector<Value*> shared;
for(size_t i = 0; i < n_shared; i++){
indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
// phase
Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
phase = urem(phase, max_phase);
// off
Value* off_0 = idx[in_order[0]];
off_0 = udiv(off_0, i32(vector));
off_0 = xor_(off_0, phase);
off_0 = mul(off_0 , i32(vector));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off = add(off_0, off_1);
//
shared.push_back(gep(shmems_[x], {off}));
}
//
for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
auto idx = idxs_[ptrs][i];
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
std::vector<std::pair<Value*, int>> shared;
for(int i = 0; i < idxs_.at(arg).size(); i++){
unsigned id = i / min_vec;
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
Value *in_base = in_gep->getPointerOperand();
size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
Value* out_base = shared[(i / per_thread_ld) % n_shared];
int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
int out_off_1 = i % per_thread_ld;
int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[msks][idx], out_base, in_base});
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
int off = (off_1*shapes[in_order[0]] + off_0);
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(tmp.find(key) == tmp.end()){
if(CurrBB != FirstBB)
builder_->SetInsertPoint(FirstBB->getTerminator());
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
off_0 = udiv(off_0, i32(min_vec));
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
off_0 = mul(off_0 , i32(min_vec));
Value* off = add(off_0, off_1);
if(CurrBB != FirstBB)
builder_->SetInsertPoint(CurrBB);
tmp[key] = gep(shmems_[x], {off});
}
shared.push_back({tmp[key], off});
}
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
}
std::string asm_str = "cp.async.commit_group;";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
}
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_tile_shapes();
// default implementation
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_async_wait_inst(ir::async_wait_inst*) {
std::string asm_str = "cp.async.wait_all;";
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
add_barrier();
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(*ctx_);
Type *int_32_ty = Type::getInt32Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
Type *ptr_ty = ptr_ty(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
shmem_ = bit_cast(sh_mem_array, ptr_ty);
}

View File

@@ -15,114 +15,105 @@ namespace triton {
namespace codegen{
namespace transform{
bool membar::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
return left_intersect || right_intersect;
});
}
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void membar::add_reference(ir::value *v, interval_vec_t &res){
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(!i->get_type()->is_tile_ty())
return;
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if(!layout)
return;
if(alloc_->has_offset(layout)){
unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
analysis::double_buffer_info_t* info = layout->get_double_buffer();
if(info)
return group_of(info->first, async_write);
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
}
else{
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
}
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_tile_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
if(!a_layout)
continue;
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
for(ir::value* b: bs){
if(!b->get_type()->is_tile_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
if(!b_layout)
continue;
int b_start = alloc_->offset(b_layout);
int b_end = b_start + b_layout->get_size();
if(a_start < b_end || b_start < a_end)
ret.insert(b);
}
}
return ret;
}
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
add_reference(i, res);
}
void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
assert(inc_val);
if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
if(type.first)
builder.create_async_wait();
if(type.second)
builder.create_barrier();
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
bool& inserted, ir::builder& builder) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
if(dynamic_cast<ir::phi_node*>(i))
continue;
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
dynamic_cast<ir::masked_load_async_inst*>(i)){
async_write.push_back(i);
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
sync_write.insert(i);
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
if(intersect_with(read, tmp).size()){
std::vector<int> groups(read.size());
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
int N = *std::max_element(groups.begin(), groups.end());
if(N < async_write.size()){
builder.set_insert_point(i);
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
}
}
else {
builder.set_insert_point(instr);
builder.create_barrier();
}
}
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
membar::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
}
std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
std::set<ir::value*>& safe_war,
std::vector<ir::instruction*>& to_sync) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from;
for(ir::instruction *i: instructions){
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
if(written.size())
to_sync.push_back(i);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
// RAW, WAR
if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
// record hazards
if(read_after_write || write_after_read) {
auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
insert_loc.insert({i, {copy_async_wait, barrier}});
new_written_to.clear();
new_read_from.clear();
to_sync.clear();
// update state of asynchronous copies
if(async_wait){
int N = async_write.size() - async_wait->get_N();
async_write.erase(async_write.begin(), async_write.begin() + N);
}
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
// all the copy_to_shared and read from shared are synchronized after barrier
if(barrier){
sync_write.clear();
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
return std::make_pair(new_written_to, new_read_from);
}
void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from;
std::vector<ir::instruction*> to_sync;
std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
size_t n_inserted_im1 = 0;
bool done = false;
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
std::map<ir::basic_block*, val_set_t> sync_reads;
std::list<ir::value *> pipelined;
bool inserted;
do{
inserted = false;
// find barrier location
for(ir::basic_block *block: rpo){
// written to
std::vector<interval_vec_t> pred_written_to;
for(ir::basic_block* pred: block->get_predecessors())
pred_written_to.push_back(written_to[pred]);
// read from
std::vector<interval_vec_t> pred_read_from;
for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]);
// apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
written_to[block] = result.first;
read_from[block] = result.second;
// join inputs
val_vec_t async_write;
val_set_t sync_write;
val_set_t sync_read;
val_set_t tmp;
for(ir::basic_block* pred: block->get_predecessors()){
for(ir::value* v: async_writes[pred])
if(tmp.insert(v).second)
async_write.push_back(v);
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
}
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
async_writes[block] = async_write;
sync_writes[block] = sync_write;
sync_reads[block] = sync_read;
}
size_t n_inserted_i = insert_locs.size();
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(auto x: insert_locs){
insert_barrier(x.first, x.second, builder);
}
}while(inserted);
}
}

View File

@@ -1,7 +1,9 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
return false;
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
// return true;
}
@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// if(tgt_->as_nvidia()->sm() >= 80)
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}

View File

@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
to_pipeline.push_back({load, ptr});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::map<ir::basic_block*, move_config_t> to_move;
if(has_copy_async_){
for(ir::function* fn: mod.get_function_list())
for(ir::basic_block* bb: fn->blocks())
for(ir::instruction* inst: bb->get_inst_list()){
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
recursive_deps(i, bb, to_move[bb].insts);
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
to_move[bb].dst = i;
}
for(auto& x: to_move){
builder.set_insert_point_after(x.second.dst);
for(ir::instruction* i: x.second.insts){
x.first->erase(i);
builder.insert(i);
}
}
}
}
}
}
}

View File

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
inline bool is_cst(ir::value *x) {
if(dynamic_cast<ir::constant*>(x))
return true;
if(dynamic_cast<ir::make_range*>(x))
return true;
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
return is_cst(v->get_operand(0));
return false;