[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)

* update membar pass when data is double buffered

* Add instruction prefetch_s

* prefetch tests pass (except the 1 warp case)

* Fix the 1-warp bug

* Add back prefetch files

* Disable prefetch on a100

* Always add war barrier on sm>=80
This commit is contained in:
daadaada
2021-05-13 10:42:18 +08:00
committed by Philippe Tillet
parent 147675923e
commit 967e629c0c
14 changed files with 408 additions and 42 deletions

View File

@@ -11,6 +11,7 @@
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
@@ -169,6 +170,7 @@ public:
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
@@ -209,16 +211,19 @@ private:
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
@@ -227,6 +232,11 @@ private:
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Record prefetch instrs that need to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};
}

View File

@@ -5,6 +5,7 @@
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
@@ -44,14 +45,16 @@ private:
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
target* tgt_;
};

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H

// forward declarations
namespace triton::ir{
class module;
}

namespace triton::codegen {
class target;
}

namespace triton::codegen::transform {

/// Transformation pass that prefetches the operands of dot instructions
/// when applicable. The pass is parameterised by the compilation target.
class prefetch {
  /// Target description handed in at construction. The pointer is stored
  /// as given; this class declares no destructor and never deletes it.
  target* tgt_;

public:
  /// Build the pass for target `tgt`.
  /// `explicit` so that a bare `target*` is never silently converted
  /// into a pass object.
  explicit prefetch(target *tgt) : tgt_(tgt) {}

  /// Apply the pass to `module`.
  void run(ir::module &module);
};

}

#endif

View File

@@ -151,6 +151,7 @@ public:
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
value *create_prefetch_s(value *arg, int inc);
private:
context &ctx_;

View File

@@ -143,7 +143,8 @@ enum value_id_t: unsigned {
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE
INST_MAKE_RANGE,
INST_PREFETCH_S,
};

View File

@@ -666,6 +666,11 @@ private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
@@ -821,6 +826,23 @@ private:
int N_;
};
/// Void-typed instruction with a single operand (`arg`) and an integer
/// selector `inc`. The prefetch pass creates it with inc = 0 in the loop
/// header and with inc = 1 at the end of the loop body.
class prefetch_s_inst : public instruction {
/// Textual name of this instruction.
std::string repr_impl() const { return "prefetch_s"; }
// Project macros; NOTE(review): presumably they supply the clone and
// visitor-accept hooks for this class -- confirm against their definitions.
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
/// Builds the instruction with void type, id INST_PREFETCH_S and exactly
/// one operand slot, then stores `arg` in operand 0 and `inc` in inc_.
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
/// Returns the selector passed at construction (0 or 1, see inc_).
int get_inc() const { return inc_; }
/// Factory; `name` defaults to "" and `next` to nullptr.
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
instruction *next=nullptr);
};
//// On NVIDIA, implementation is such that
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
//// so as to enable re-association on nv_static_program_idx which is constant

View File

@@ -70,6 +70,7 @@ class barrier_inst;
class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class make_range_sta;
class undef_value;
@@ -146,6 +147,7 @@ public:
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;

View File

@@ -12,6 +12,7 @@
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
@@ -44,11 +45,12 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get());
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
// codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target.get());
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
dce.run(ir);
@@ -90,8 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
barriers.run(ir);
// ir::print(ir, std::cout);
prefetch_s.run(ir);
barriers.run(ir);
// ir::print(ir, std::cout);
isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str());

View File

@@ -1048,65 +1048,208 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
for(indices_t idx: idxs_.at(C))
acc.push_back(vals_[D][idx]);
// update accumulators
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end()){
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
// update accumulators
if (C->is_prefetched()) {
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
/// Cache lds value. If values are prefetched, create phi node
auto register_lds =
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
if (K == 0) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto load_a = [&](int m, int K, int inc) ->void {
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
Value* ptra;
if(K==0){
if(inc == 0) {
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
}
else {
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
}
else
ptra = ptr_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
has[{m, K}] = {ha00, ha01};
register_lds(has, m, K, inc, ha00, ha01, A);
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
has[{m, K+4}] = {ha10, ha11};
register_lds(has, m, K+4, inc, ha10, ha11, A);
else
has[{m+1, K}] = {ha10, ha11};
register_lds(has, m+1, K, inc, ha10, ha11, A);
}
}
if(hbs.find({n, K}) == hbs.end()){
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
};
auto load_b = [&](int n, int K, int inc){
int offidx = (is_b_row? n : K/4) % num_ptr_b;
Value* ptrb;
if(K==0){
if(inc == 0) {
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
}
else {
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
}
else
ptrb = ptr_b[offidx];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
hbs[{n, K}] = {hb00, hb01};
register_lds(hbs, n, K, inc, hb00, hb01, B);
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
hbs[{n+1, K}] = {hb10, hb11};
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
else
hbs[{n, K+4}] = {hb10, hb11};
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
}
};
// create phis
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
has[{m, 0}].first = phi(f16x2_ty, 2);
has[{m, 0}].second = phi(f16x2_ty, 2);
if (!is_a_row && vec_a>4) {
has[{m+1, 0}].first = phi(f16x2_ty, 2);
has[{m+1, 0}].second = phi(f16x2_ty, 2);
}
}
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
hbs[{n, 0}].first = phi(f16x2_ty, 2);
hbs[{n, 0}].second = phi(f16x2_ty, 2);
if (is_b_row && vec_b>4) {
hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
}
}
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
load_a(m, 0, 0);
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
load_b(n, 0, 0);
// update accumulators
builder_->SetInsertPoint(curr_bb);
for (unsigned K = 0; K < NK; K += 4) {
int NEXTK = (K + 4) % NK;
// prefetch A
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
load_a(m, NEXTK, 1);
// prefetch B
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
load_b(n, NEXTK, 1);
// tensor core ops
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++){
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
}
}
} else { // not prefetched
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end()){
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
has[{m, K}] = {ha00, ha01};
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
has[{m, K+4}] = {ha10, ha11};
else
has[{m+1, K}] = {ha10, ha11};
}
}
if(hbs.find({n, K}) == hbs.end()){
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
hbs[{n, K}] = {hb00, hb01};
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
hbs[{n+1, K}] = {hb10, hb11};
else
hbs[{n, K+4}] = {hb10, hb11};
}
}
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
}
}
// write back accumulators
@@ -1827,6 +1970,40 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
/// Lower a prefetch_s instruction.
///
/// For inc == 0 nothing is emitted here. For inc == 1 the LLVM values that
/// were recorded in prefetch_latch_to_bb_ for the operand, together with all
/// of their transitive users, are moved to the end of the block the builder
/// is currently inserting into.
///
/// @param i the prefetch_s instruction; operand 0 is the key used to look up
///          the recorded values.
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
  ir::value *v = i->get_operand(0);
  int inc = i->get_inc();
  // inc == 0: the dot has not been visited yet, nothing to do.
  if (inc == 0)
    return;

  // inc == 1: the dot has been visited, so its recorded loads can be moved.
  assert(inc == 1);
  assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
         "dot hasn't be visited");

  // Sink the recorded loads and everything that uses them to the current
  // location.
  std::stack<Value*> work_stack;
  for (Value *value : prefetch_latch_to_bb_[v])
    work_stack.push(value);

  while (!work_stack.empty()) {
    Value *m = work_stack.top();
    work_stack.pop();

    // Users are queued behind their definition. A value reached along more
    // than one path is moved again each time, and so are its own users.
    for (auto u : m->users())
      work_stack.push(u);

    assert(isa<Instruction>(m));
    auto m_instr = static_cast<Instruction*>(m);

    // Detach the instruction and re-attach it right after the current last
    // instruction of the insertion block.
    m_instr->removeFromParent();
    m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
    assert(m_instr->getParent() == &*builder_->GetInsertBlock());

    // Re-seat the builder at the end of that block.
    builder_->SetInsertPoint(m_instr->getParent());
  }
}
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
@@ -2144,6 +2321,7 @@ void generator::visit_basic_block(ir::basic_block * block) {
for(ir::instruction *i: block->get_inst_list()){
visit_value(i);
}
// Update ir bb -> llvm bb mapping
bbs_[block] = builder_->GetInsertBlock();
}
@@ -2247,6 +2425,8 @@ void generator::finalize_function(ir::function *fn) {
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
finalize_phi_node(phi);
for(auto& x: lazy_phi_incs_)
std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
}
void generator::finalize_phi_node(ir::phi_node *x) {

View File

@@ -96,7 +96,15 @@ void membar::transfer(ir::basic_block *block,
}
}
// RAW, WAR
if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
bool is_i_double_buffered = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared() &&
layouts_->get(i)->to_shared()->get_double_buffer();
// WAR barrier is not required when data is double-buffered
// TODO: how about other patterns, like WWAR?
if(!intersect_with(read, sync_write).empty() ||
(!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
// force WAR barrier on A100
(!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
@@ -133,6 +141,7 @@ void membar::run(ir::module &mod) {
}
for(ir::function *fn: mod.get_function_list()){
// TODO: (dyan) we need DominatorTree here.
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
@@ -166,4 +175,4 @@ void membar::run(ir::module &mod) {
}
}
}
}

View File

@@ -0,0 +1,105 @@
#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/print.h"
#include <iostream>
#include <vector>
#include <algorithm>
namespace triton::codegen::transform {
/// Collect, in pre-order, `v` and the instructions it transitively depends
/// on, restricted to non-phi instructions whose parent block is `bb`.
/// Traversal stops at anything that is not an instruction, at instructions
/// of other blocks and at phi nodes. Instructions reachable along several
/// paths are appended once per path; `ret` is not de-duplicated here.
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
  auto *def = dynamic_cast<ir::instruction*>(v);
  const bool is_local_non_phi =
      def && def->get_parent() == bb && def->get_id() != ir::INST_PHI;
  if (is_local_non_phi) {
    ret.push_back(def);
    for (ir::value *operand : def->ops())
      recursive_defs(operand, bb, ret);
  }
}
/// Run the prefetch pass on `mod`.
///
/// 1. Collect every dot whose first operand has half scalar type, whose
///    target reports sm < 80, and whose two operands are phi nodes with
///    incoming block 1 equal to the phi's own parent block.
/// 2. Mark each collected dot as prefetched. With the builder's insert point
///    set to the last instruction of the phis' incoming block 0, create a
///    barrier and two prefetch_s (inc = 0); do the same in the phis' parent
///    block with inc = 1.
/// 3. On sm < 80 only: in every block with exactly two predecessors whose
///    second predecessor is the block itself, move each masked load and its
///    same-block, non-phi definitions to the first non-phi position, keeping
///    their original relative order.
void prefetch::run(ir::module &mod) {
  // 1. collect dots that can be prefetched
  std::vector<ir::dot_inst*> to_prefetch;
  ir::for_each_instruction(mod, [&](ir::instruction *i) {
    if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
      // Now only do prefetching when dot is fp16 & sm < 80
      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID ||
          tgt_->as_nvidia()->sm() >= 80)
        return;
      auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
      auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
      if (a && a->get_incoming_block(1) == a->get_parent() &&
          b && b->get_incoming_block(1) == b->get_parent())
        to_prefetch.push_back(dot);
    }
  });

  assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots");
  ir::builder &builder = mod.get_builder();

  // 2. do the prefetching
  for (ir::dot_inst* dot : to_prefetch) {
    auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
    auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
    // Check the casts before the first dereference below.
    assert(a && b);
    assert(a->get_incoming_block(0) == b->get_incoming_block(0));
    ir::basic_block *loop_header = a->get_incoming_block(0);
    ir::basic_block *loop_body = a->get_parent();

    // mark as prefetched
    dot->set_prefetched(true);

    // 2.1 in the loop header (first iteration)
    builder.set_insert_point(loop_header->get_inst_list().back());
    builder.create_barrier();
    builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
    builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);

    // 2.2 at the end of the loop body (next iteration)
    builder.set_insert_point(loop_body->get_inst_list().back());
    builder.create_barrier();
    builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
    builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
  }

  // 3. move loads to the beginning of the loop
  if (tgt_->as_nvidia()->sm() < 80) {
    for (ir::function *fn : mod.get_function_list())
    for (ir::basic_block *bb : fn->blocks()) {
      // only apply to loop body
      if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
        continue;
      // record loads (& dependency) to move
      std::vector<ir::instruction*> loads;
      // record original inst order
      std::map<ir::instruction*, size_t> idx_map;
      size_t idx = 0;
      for (ir::instruction *inst : bb->get_inst_list()) {
        if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
          recursive_defs(i, bb, loads);
        idx_map[inst] = idx;
        idx++;
      }

      // remove duplicates & keep the original input order:
      // sort by address so that std::unique can drop repeats, then restore
      // program order using the recorded indices.
      std::sort(loads.begin(), loads.end());
      loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
      std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
        return idx_map[a] < idx_map[b];
      });

      // re-insert the collected instructions at the first non-phi position
      builder.set_insert_point(bb->get_first_non_phi());
      for (ir::instruction *i : loads) {
        bb->erase(i);
        builder.insert(i);
      }
    }
  }
}
} // namespace triton::codegen::transform

View File

@@ -209,7 +209,8 @@ static std::map<int, int> vptx = {
{10020, 65},
{11000, 70},
{11010, 71},
{11020, 72}
{11020, 72},
{11030, 73},
};
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {

View File

@@ -381,6 +381,9 @@ value *builder::create_async_wait(int N) {
return insert(async_wait_inst::create(ctx_, N));
}
/// Create a prefetch_s instruction for `arg` with selector `inc` and insert
/// it through this builder; returns the inserted instruction.
value *builder::create_prefetch_s(value *arg, int inc) {
  auto *prefetch = prefetch_s_inst::create(ctx_, arg, inc);
  return insert(prefetch);
}
}

View File

@@ -832,6 +832,10 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
return new async_wait_inst(ctx, N, name, next);
}
// prefetch_s
/// Heap-allocate a prefetch_s instruction, forwarding every argument to the
/// constructor, and return the raw pointer.
prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
  auto *inst = new prefetch_s_inst(ctx, arg, inc, name, next);
  return inst;
}
//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)