[CODEGEN] Clean up visit_mma884 (#107)

This commit is contained in:
daadaada
2021-05-14 02:13:26 +08:00
committed by Philippe Tillet
parent 967e629c0c
commit 840d65d8c6

View File

@@ -1051,90 +1051,107 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
};
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
// Cache lds value. If values are prefetched, create phi node
// @param inc: incoming block (0 = header, 1 = loop)
auto register_lds =
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K == 0 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
Value* ptra;
if(K==0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptr_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
register_lds(has, m, K, inc, ha00, ha01, is_prefetch);
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch);
else
register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch);
}
};
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row? n : K/4) % num_ptr_b;
Value* ptrb;
if(K==0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
} else
ptrb = ptr_b[offidx];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch);
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch);
else
register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch);
}
};
// update accumulators
if (C->is_prefetched()) {
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
/// Cache lds value. If values are prefetched, create phi node
auto register_lds =
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
if (K == 0) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto load_a = [&](int m, int K, int inc) ->void {
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
Value* ptra;
if(K==0){
if(inc == 0) {
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
}
else {
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
}
else
ptra = ptr_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
register_lds(has, m, K, inc, ha00, ha01, A);
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
register_lds(has, m, K+4, inc, ha10, ha11, A);
else
register_lds(has, m+1, K, inc, ha10, ha11, A);
}
};
auto load_b = [&](int n, int K, int inc){
int offidx = (is_b_row? n : K/4) % num_ptr_b;
Value* ptrb;
if(K==0){
if(inc == 0) {
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
}
else {
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
}
else
ptrb = ptr_b[offidx];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
// record lds that needs to be moved
if (K == 0 && inc == 1)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
register_lds(hbs, n, K, inc, hb00, hb01, B);
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
else
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
}
};
// create phis
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
@@ -1154,11 +1171,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
}
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
load_a(m, 0, 0);
load_a(m, 0, 0, true);
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
load_b(n, 0, 0);
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(curr_bb);
@@ -1166,89 +1184,25 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
int NEXTK = (K + 4) % NK;
// prefetch A
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
load_a(m, NEXTK, 1);
load_a(m, NEXTK, 1, true);
// prefetch B
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
load_b(n, NEXTK, 1);
load_b(n, NEXTK, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++){
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
call_mma(m, n, K);
}
}
} else { // not prefetched
for(unsigned K = 0; K < NK; K += 4)
for(unsigned m = 0; m < num_m/2; m++)
for(unsigned n = 0; n < num_n/2; n++) {
if(has.find({m, K}) == has.end()){
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
has[{m, K}] = {ha00, ha01};
if(vec_a > 4){
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
if(is_a_row)
has[{m, K+4}] = {ha10, ha11};
else
has[{m+1, K}] = {ha10, ha11};
}
}
if(hbs.find({n, K}) == hbs.end()){
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
hbs[{n, K}] = {hb00, hb01};
if(vec_b > 4){
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
if(is_b_row)
hbs[{n+1, K}] = {hb10, hb11};
else
hbs[{n, K+4}] = {hb10, hb11};
}
}
auto ha = has[{m, K}];
auto hb = hbs[{n, K}];
// arguments
std::vector<size_t> idx = {
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
};
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
for(unsigned i = 0; i < 8; i++)
args.push_back(acc[idx[i]]);
// execute mma
Value *nc = call(mma, args);
// unpack
for(unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(nc, {i});
if(has.find({m, K}) == has.end())
load_a(m, K, /*inc*/0, /*is_prefetch*/false);
if(hbs.find({n, K}) == hbs.end())
load_b(n, K, /*inc*/0, /*is_prefetch*/false);
call_mma(m, n, K);
}
}