[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)
* Update membar pass when data is double-buffered * Add instruction prefetch_s * Prefetch tests pass (except the 1-warp case) * Fix the 1-warp bug * Add back prefetch files * Disable prefetch on A100 * Always add WAR barrier on sm>=80
This commit is contained in:
committed by
Philippe Tillet
parent
147675923e
commit
967e629c0c
@@ -1048,65 +1048,208 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
for(indices_t idx: idxs_.at(C))
|
||||
acc.push_back(vals_[D][idx]);
|
||||
|
||||
// update accumulators
|
||||
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
|
||||
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
|
||||
for(unsigned K = 0; K < NK; K += 4)
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++) {
|
||||
if(has.find({m, K}) == has.end()){
|
||||
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
|
||||
|
||||
// update accumulators
|
||||
if (C->is_prefetched()) {
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
/// Cache lds value. If values are prefetched, create phi node
|
||||
auto register_lds =
|
||||
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
|
||||
if (K == 0) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
|
||||
} else
|
||||
vals[{m, K}] = {val0, val1};
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int K, int inc) ->void {
|
||||
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
|
||||
Value* ptra;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
else {
|
||||
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptra = ptr_a[offidx];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
has[{m, K}] = {ha00, ha01};
|
||||
register_lds(has, m, K, inc, ha00, ha01, A);
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
has[{m, K+4}] = {ha10, ha11};
|
||||
register_lds(has, m, K+4, inc, ha10, ha11, A);
|
||||
else
|
||||
has[{m+1, K}] = {ha10, ha11};
|
||||
register_lds(has, m+1, K, inc, ha10, ha11, A);
|
||||
}
|
||||
}
|
||||
if(hbs.find({n, K}) == hbs.end()){
|
||||
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
|
||||
};
|
||||
|
||||
auto load_b = [&](int n, int K, int inc){
|
||||
int offidx = (is_b_row? n : K/4) % num_ptr_b;
|
||||
Value* ptrb;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
else {
|
||||
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptrb = ptr_b[offidx];
|
||||
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
hbs[{n, K}] = {hb00, hb01};
|
||||
register_lds(hbs, n, K, inc, hb00, hb01, B);
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
hbs[{n+1, K}] = {hb10, hb11};
|
||||
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
|
||||
else
|
||||
hbs[{n, K+4}] = {hb10, hb11};
|
||||
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// create phis
|
||||
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
|
||||
has[{m, 0}].first = phi(f16x2_ty, 2);
|
||||
has[{m, 0}].second = phi(f16x2_ty, 2);
|
||||
if (!is_a_row && vec_a>4) {
|
||||
has[{m+1, 0}].first = phi(f16x2_ty, 2);
|
||||
has[{m+1, 0}].second = phi(f16x2_ty, 2);
|
||||
}
|
||||
}
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
|
||||
hbs[{n, 0}].first = phi(f16x2_ty, 2);
|
||||
hbs[{n, 0}].second = phi(f16x2_ty, 2);
|
||||
if (is_b_row && vec_b>4) {
|
||||
hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
|
||||
hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
|
||||
}
|
||||
}
|
||||
|
||||
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
|
||||
load_a(m, 0, 0);
|
||||
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
|
||||
load_b(n, 0, 0);
|
||||
|
||||
// update accumulators
|
||||
builder_->SetInsertPoint(curr_bb);
|
||||
for (unsigned K = 0; K < NK; K += 4) {
|
||||
int NEXTK = (K + 4) % NK;
|
||||
// prefetch A
|
||||
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
|
||||
load_a(m, NEXTK, 1);
|
||||
// prefetch B
|
||||
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
|
||||
load_b(n, NEXTK, 1);
|
||||
// tensor core ops
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++){
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
}
|
||||
}
|
||||
} else { // not prefetched
|
||||
for(unsigned K = 0; K < NK; K += 4)
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++) {
|
||||
if(has.find({m, K}) == has.end()){
|
||||
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
has[{m, K}] = {ha00, ha01};
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
has[{m, K+4}] = {ha10, ha11};
|
||||
else
|
||||
has[{m+1, K}] = {ha10, ha11};
|
||||
}
|
||||
}
|
||||
if(hbs.find({n, K}) == hbs.end()){
|
||||
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
hbs[{n, K}] = {hb00, hb01};
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
hbs[{n+1, K}] = {hb10, hb11};
|
||||
else
|
||||
hbs[{n, K+4}] = {hb10, hb11};
|
||||
}
|
||||
}
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
}
|
||||
}
|
||||
|
||||
// write back accumulators
|
||||
@@ -1827,6 +1970,40 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
// Lowers a prefetch_s instruction.
//
// When get_inc() is 0 nothing is emitted here. When it is 1, the LLVM values
// recorded in prefetch_latch_to_bb_ under this instruction's operand are
// moved, together with every transitive user of those values, to the end of
// the block the builder is currently inserting into.
//
// NOTE(review): the entries in prefetch_latch_to_bb_ appear to be the lds
// values the dot lowering records for K == 0 / inc == 1 -- confirm against
// the dot visitor.
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
  ir::value *v = i->get_operand(0);
  int inc = i->get_inc();
  if (inc == 0) {
    // If dot has not been visited, do nothing.
  } else {
    // If dot has been visited, insert prefetched lds
    assert(inc == 1);
    assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
           "dot hasn't be visited");
    // sink lds & extract element
    // move lds & all uses to current location
    std::stack<Value*> work_stack;
    for (Value *value : prefetch_latch_to_bb_[v])
      work_stack.push(value);
    while (!work_stack.empty()) {
      Value *m = work_stack.top();
      work_stack.pop();

      // Queue the users before relocating `m`: they are popped later and are
      // therefore re-inserted after their operand. A user reachable through
      // several operands is simply moved to the end again on each visit.
      for (auto u : m->users())
        work_stack.push(u);

      assert(isa<Instruction>(m));
      auto m_instr = static_cast<Instruction*>(m);

      // Detach from the original block and append after the current last
      // instruction of the insert block (the block must not be empty here,
      // since std::prev(end()) is dereferenced).
      m_instr->removeFromParent();
      m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
      assert(m_instr->getParent() == &*builder_->GetInsertBlock());
      // Keep the builder appending at the end of that block.
      builder_->SetInsertPoint(m_instr->getParent());
    }
  }
}
|
||||
|
||||
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
|
||||
@@ -2144,6 +2321,7 @@ void generator::visit_basic_block(ir::basic_block * block) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
visit_value(i);
|
||||
}
|
||||
// Update ir bb -> llvm bb mapping
|
||||
bbs_[block] = builder_->GetInsertBlock();
|
||||
}
|
||||
|
||||
@@ -2247,6 +2425,8 @@ void generator::finalize_function(ir::function *fn) {
|
||||
for(ir::instruction *inst: block->get_inst_list())
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
|
||||
finalize_phi_node(phi);
|
||||
for(auto& x: lazy_phi_incs_)
|
||||
std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
|
||||
}
|
||||
|
||||
void generator::finalize_phi_node(ir::phi_node *x) {
|
||||
|
Reference in New Issue
Block a user