[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)

* update membar pass when data is double buffered

* Add instruction prefetch_s

* prefetch tests pass (except the 1 warp case)

* Fix the 1-warp bug

* Add back prefetch files

* Disable prefetch on a100

* Always add war barrier on sm>=80
This commit is contained in:
daadaada
2021-05-13 10:42:18 +08:00
committed by Philippe Tillet
parent 147675923e
commit 967e629c0c
14 changed files with 408 additions and 42 deletions

View File

@@ -1048,65 +1048,208 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
for(indices_t idx: idxs_.at(C))
acc.push_back(vals_[D][idx]);
// update accumulators
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end()){
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
// update accumulators
if (C->is_prefetched()) {
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
/// Cache lds value. If values are prefetched, create phi node
auto register_lds =
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
if (K == 0) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto load_a = [&](int m, int K, int inc) ->void {
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
Value* ptra;
if(K==0){
if(inc == 0) {
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
}
else {
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
}
else
ptra = ptr_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
has[{m, K}] = {ha00, ha01};
register_lds(has, m, K, inc, ha00, ha01, A);
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
has[{m, K+4}] = {ha10, ha11};
register_lds(has, m, K+4, inc, ha10, ha11, A);
else
has[{m+1, K}] = {ha10, ha11};
register_lds(has, m+1, K, inc, ha10, ha11, A);
}
}
if(hbs.find({n, K}) == hbs.end()){
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
};
auto load_b = [&](int n, int K, int inc){
int offidx = (is_b_row? n : K/4) % num_ptr_b;
Value* ptrb;
if(K==0){
if(inc == 0) {
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
}
else {
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
}
else
ptrb = ptr_b[offidx];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
hbs[{n, K}] = {hb00, hb01};
register_lds(hbs, n, K, inc, hb00, hb01, B);
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
hbs[{n+1, K}] = {hb10, hb11};
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
else
hbs[{n, K+4}] = {hb10, hb11};
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
}
};
// create phis
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
has[{m, 0}].first = phi(f16x2_ty, 2);
has[{m, 0}].second = phi(f16x2_ty, 2);
if (!is_a_row && vec_a>4) {
has[{m+1, 0}].first = phi(f16x2_ty, 2);
has[{m+1, 0}].second = phi(f16x2_ty, 2);
}
}
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
hbs[{n, 0}].first = phi(f16x2_ty, 2);
hbs[{n, 0}].second = phi(f16x2_ty, 2);
if (is_b_row && vec_b>4) {
hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
}
}
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
load_a(m, 0, 0);
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
load_b(n, 0, 0);
// update accumulators
builder_->SetInsertPoint(curr_bb);
for (unsigned K = 0; K < NK; K += 4) {
int NEXTK = (K + 4) % NK;
// prefetch A
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
load_a(m, NEXTK, 1);
// prefetch B
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
load_b(n, NEXTK, 1);
// tensor core ops
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++){
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
}
}
} else { // not prefetched
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end()){
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
has[{m, K}] = {ha00, ha01};
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
has[{m, K+4}] = {ha10, ha11};
else
has[{m+1, K}] = {ha10, ha11};
}
}
if(hbs.find({n, K}) == hbs.end()){
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
hbs[{n, K}] = {hb00, hb01};
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
hbs[{n+1, K}] = {hb10, hb11};
else
hbs[{n, K+4}] = {hb10, hb11};
}
}
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
}
}
// write back accumulators
@@ -1827,6 +1970,40 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
ir::value *v = i->get_operand(0);
int inc = i->get_inc();
if (inc == 0) {
// If dot has not been visitied, do nothing.
} else {
// If dot has been visitied, insert prefetched lds
assert(inc == 1);
assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
"dot hasn't be visited");
// sink lds & extract element
// move lds & all uses to current location
std::stack<Value*> work_stack;
for (Value *value : prefetch_latch_to_bb_[v])
work_stack.push(value);
std::vector<Instruction*> dead_instrs;
while (!work_stack.empty()) {
Value *m = work_stack.top();
work_stack.pop();
for (auto u : m->users())
work_stack.push(u);
assert(isa<Instruction>(m));
auto m_instr = static_cast<Instruction*>(m);
m_instr->removeFromParent();
m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
assert(m_instr->getParent() == &*builder_->GetInsertBlock());
builder_->SetInsertPoint(m_instr->getParent());
}
}
}
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
@@ -2144,6 +2321,7 @@ void generator::visit_basic_block(ir::basic_block * block) {
for(ir::instruction *i: block->get_inst_list()){
visit_value(i);
}
// Update ir bb -> llvm bb mapping
bbs_[block] = builder_->GetInsertBlock();
}
@@ -2247,6 +2425,8 @@ void generator::finalize_function(ir::function *fn) {
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
finalize_phi_node(phi);
for(auto& x: lazy_phi_incs_)
std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
}
void generator::finalize_phi_node(ir::phi_node *x) {