[CODEGEN] Clean up visit_mma884 (#107)
This commit is contained in:
committed by
Philippe Tillet
parent
967e629c0c
commit
840d65d8c6
@@ -1051,90 +1051,107 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
|
||||
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
|
||||
|
||||
// create mma & unpack result
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
};
|
||||
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
// Cache lds value. If values are prefetched, create phi node
|
||||
// @param inc: incoming block (0 = header, 1 = loop)
|
||||
auto register_lds =
|
||||
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
|
||||
if (K == 0 && is_prefetch) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
|
||||
} else
|
||||
vals[{m, K}] = {val0, val1};
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
|
||||
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
|
||||
Value* ptra;
|
||||
if(K==0 && is_prefetch){
|
||||
if(inc == 0)
|
||||
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
|
||||
else
|
||||
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
else
|
||||
ptra = ptr_a[offidx];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1 && is_prefetch)
|
||||
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
register_lds(has, m, K, inc, ha00, ha01, is_prefetch);
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch);
|
||||
else
|
||||
register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch);
|
||||
}
|
||||
};
|
||||
|
||||
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
|
||||
int offidx = (is_b_row? n : K/4) % num_ptr_b;
|
||||
Value* ptrb;
|
||||
if(K==0 && is_prefetch){
|
||||
if(inc == 0)
|
||||
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
|
||||
else
|
||||
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
|
||||
} else
|
||||
ptrb = ptr_b[offidx];
|
||||
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1 && is_prefetch)
|
||||
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch);
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch);
|
||||
else
|
||||
register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// update accumulators
|
||||
if (C->is_prefetched()) {
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
/// Cache lds value. If values are prefetched, create phi node
|
||||
auto register_lds =
|
||||
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
|
||||
if (K == 0) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
|
||||
} else
|
||||
vals[{m, K}] = {val0, val1};
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int K, int inc) ->void {
|
||||
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
|
||||
Value* ptra;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
else {
|
||||
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptra = ptr_a[offidx];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
register_lds(has, m, K, inc, ha00, ha01, A);
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
register_lds(has, m, K+4, inc, ha10, ha11, A);
|
||||
else
|
||||
register_lds(has, m+1, K, inc, ha10, ha11, A);
|
||||
}
|
||||
};
|
||||
|
||||
auto load_b = [&](int n, int K, int inc){
|
||||
int offidx = (is_b_row? n : K/4) % num_ptr_b;
|
||||
Value* ptrb;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
else {
|
||||
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptrb = ptr_b[offidx];
|
||||
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
register_lds(hbs, n, K, inc, hb00, hb01, B);
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
|
||||
else
|
||||
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// create phis
|
||||
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
|
||||
@@ -1154,11 +1171,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
}
|
||||
}
|
||||
|
||||
// insert prefetched lds at the end of loop header
|
||||
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
|
||||
load_a(m, 0, 0);
|
||||
load_a(m, 0, 0, true);
|
||||
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
|
||||
load_b(n, 0, 0);
|
||||
load_b(n, 0, 0, true);
|
||||
|
||||
// update accumulators
|
||||
builder_->SetInsertPoint(curr_bb);
|
||||
@@ -1166,89 +1184,25 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
int NEXTK = (K + 4) % NK;
|
||||
// prefetch A
|
||||
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
|
||||
load_a(m, NEXTK, 1);
|
||||
load_a(m, NEXTK, 1, true);
|
||||
// prefetch B
|
||||
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
|
||||
load_b(n, NEXTK, 1);
|
||||
load_b(n, NEXTK, 1, true);
|
||||
// tensor core ops
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++){
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
call_mma(m, n, K);
|
||||
}
|
||||
}
|
||||
} else { // not prefetched
|
||||
for(unsigned K = 0; K < NK; K += 4)
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++) {
|
||||
if(has.find({m, K}) == has.end()){
|
||||
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
has[{m, K}] = {ha00, ha01};
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
has[{m, K+4}] = {ha10, ha11};
|
||||
else
|
||||
has[{m+1, K}] = {ha10, ha11};
|
||||
}
|
||||
}
|
||||
if(hbs.find({n, K}) == hbs.end()){
|
||||
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
hbs[{n, K}] = {hb00, hb01};
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
hbs[{n+1, K}] = {hb10, hb11};
|
||||
else
|
||||
hbs[{n, K+4}] = {hb10, hb11};
|
||||
}
|
||||
}
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
if(has.find({m, K}) == has.end())
|
||||
load_a(m, K, /*inc*/0, /*is_prefetch*/false);
|
||||
if(hbs.find({n, K}) == hbs.end())
|
||||
load_b(n, K, /*inc*/0, /*is_prefetch*/false);
|
||||
call_mma(m, n, K);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user