diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 671ac1071..bb9f518e9 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1051,90 +1051,107 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
   unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
 
+  // create mma & unpack result
+  auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
+    auto ha = has[{m, K}];
+    auto hb = hbs[{n, K}];
+    // arguments
+    std::vector<size_t> idx = {
+      (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
+      (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
+      (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
+      (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
+    };
+    std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
+    for(unsigned i = 0; i < 8; i++)
+      args.push_back(acc[idx[i]]);
+    // execute mma
+    Value *nc = call(mma, args);
+    // unpack
+    for(unsigned i = 0; i < 8; i++)
+      acc[idx[i]] = extract_val(nc, {i});
+  };
+
+  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
+  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
+
+  // Cache lds value. If values are prefetched, create phi node
+  // @param inc: incoming block (0 = header, 1 = loop)
+  auto register_lds =
+    [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
+      if (K == 0 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
+      } else
+        vals[{m, K}] = {val0, val1};
+  };
+
+  auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
+    int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
+    Value* ptra;
+    if(K==0 && is_prefetch){
+      if(inc == 0)
+        ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
+      else
+        ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
+    }
+    else
+      ptra = ptr_a[offidx];
+    int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
+    int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
+    Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
+    Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
+    // record lds that needs to be moved
+    if (K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
+    Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
+    Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
+    register_lds(has, m, K, inc, ha00, ha01, is_prefetch);
+    if(vec_a > 4){
+      Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
+      Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
+      if(is_a_row)
+        register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch);
+      else
+        register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch);
+    }
+  };
+
+  auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
+    int offidx = (is_b_row? n : K/4) % num_ptr_b;
+    Value* ptrb;
+    if(K==0 && is_prefetch){
+      if(inc == 0)
+        ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
+      else
+        ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
+    } else
+      ptrb = ptr_b[offidx];
+
+    int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
+    int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
+    Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
+    Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
+    // record lds that needs to be moved
+    if (K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
+    Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
+    Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
+    register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch);
+    if(vec_b > 4){
+      Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
+      Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
+      if(is_b_row)
+        register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch);
+      else
+        register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch);
+    }
+
+  };
+
   // update accumulators
   if (C->is_prefetched()) {
-    ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
-    ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
-
-    /// Cache lds value. If values are prefetched, create phi node
-    auto register_lds =
-      [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
-        if (K == 0) {
-          ir::basic_block* inc_block = phiA->get_incoming_block(inc);
-          lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
-          lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
-        } else
-          vals[{m, K}] = {val0, val1};
-    };
-
-    auto load_a = [&](int m, int K, int inc) ->void {
-      int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
-      Value* ptra;
-      if(K==0){
-        if(inc == 0) {
-          ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
-        }
-        else {
-          ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
-        }
-      }
-      else
-        ptra = ptr_a[offidx];
-      int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
-      int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
-      Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
-      Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
-      // record lds that needs to be moved
-      if (K == 0 && inc == 1)
-        prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
-      Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
-      Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
-      register_lds(has, m, K, inc, ha00, ha01, A);
-      if(vec_a > 4){
-        Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
-        Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
-        if(is_a_row)
-          register_lds(has, m, K+4, inc, ha10, ha11, A);
-        else
-          register_lds(has, m+1, K, inc, ha10, ha11, A);
-      }
-    };
-
-    auto load_b = [&](int n, int K, int inc){
-      int offidx = (is_b_row? n : K/4) % num_ptr_b;
-      Value* ptrb;
-      if(K==0){
-        if(inc == 0) {
-          ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
-        }
-        else {
-          ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
-        }
-      }
-      else
-        ptrb = ptr_b[offidx];
-
-      int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
-      int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
-      Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
-      Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
-      // record lds that needs to be moved
-      if (K == 0 && inc == 1)
-        prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
-      Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
-      Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
-      register_lds(hbs, n, K, inc, hb00, hb01, B);
-      if(vec_b > 4){
-        Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
-        Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
-        if(is_b_row)
-          register_lds(hbs, n+1, K, inc, hb10, hb11, B);
-        else
-          register_lds(hbs, n, K+4, inc, hb10, hb11, B);
-      }
-
-    };
-
     // create phis
     builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
     for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
@@ -1154,11 +1171,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
     }
   }
 
+  // insert prefetched lds at the end of loop header
   builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
   for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
-    load_a(m, 0, 0);
+    load_a(m, 0, 0, true);
   for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
-    load_b(n, 0, 0);
+    load_b(n, 0, 0, true);
 
   // update accumulators
   builder_->SetInsertPoint(curr_bb);
@@ -1166,89 +1184,25 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
     int NEXTK = (K + 4) % NK;
     // prefetch A
     for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
-      load_a(m, NEXTK, 1);
+      load_a(m, NEXTK, 1, true);
     // prefetch B
     for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
-      load_b(n, NEXTK, 1);
+      load_b(n, NEXTK, 1, true);
     // tensor core ops
     for(unsigned m = 0; m < num_m/2; m++)
     for(unsigned n = 0; n < num_n/2; n++){
-      auto ha = has[{m, K}];
-      auto hb = hbs[{n, K}];
-      // arguments
-      std::vector<size_t> idx = {
-        (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
-        (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
-        (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
-        (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
-      };
-      std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
-      for(unsigned i = 0; i < 8; i++)
-        args.push_back(acc[idx[i]]);
-      // execute mma
-      Value *nc = call(mma, args);
-      // unpack
-      for(unsigned i = 0; i < 8; i++)
-        acc[idx[i]] = extract_val(nc, {i});
+      call_mma(m, n, K);
     }
   }
   }
   else { // not prefetched
     for(unsigned K = 0; K < NK; K += 4)
     for(unsigned m = 0; m < num_m/2; m++)
    for(unsigned n = 0; n < num_n/2; n++) {
-      if(has.find({m, K}) == has.end()){
-        Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
-        int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
-        int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
-        Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
-        Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
-        Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
-        Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
-        has[{m, K}] = {ha00, ha01};
-        if(vec_a > 4){
-          Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
-          Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
-          if(is_a_row)
-            has[{m, K+4}] = {ha10, ha11};
-          else
-            has[{m+1, K}] = {ha10, ha11};
-        }
-      }
-      if(hbs.find({n, K}) == hbs.end()){
-        Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
-        int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
-        int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
-        Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
-        Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
-        Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
-        Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
-        hbs[{n, K}] = {hb00, hb01};
-        if(vec_b > 4){
-          Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
-          Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
-          if(is_b_row)
-            hbs[{n+1, K}] = {hb10, hb11};
-          else
-            hbs[{n, K+4}] = {hb10, hb11};
-        }
-      }
-      auto ha = has[{m, K}];
-      auto hb = hbs[{n, K}];
-      // arguments
-      std::vector<size_t> idx = {
-        (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
-        (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
-        (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
-        (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
-      };
-      std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
-      for(unsigned i = 0; i < 8; i++)
-        args.push_back(acc[idx[i]]);
-      // execute mma
-      Value *nc = call(mma, args);
-      // unpack
-      for(unsigned i = 0; i < 8; i++)
-        acc[idx[i]] = extract_val(nc, {i});
+      if(has.find({m, K}) == has.end())
+        load_a(m, K, /*inc*/0, /*is_prefetch*/false);
+      if(hbs.find({n, K}) == hbs.end())
+        load_b(n, K, /*inc*/0, /*is_prefetch*/false);
+      call_mma(m, n, K);
     }
   }
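
Reviewer note: the snippet below is a minimal, self-contained sketch of the deduplication pattern this patch applies, not the Triton codegen API; the names `cached`, `load`, and `compute` are hypothetical stand-ins for `has`/`hbs`, `load_a`/`load_b`, and `call_mma`. The point is that the operand-load and MMA-call logic is defined once, with an `is_prefetch` flag selecting between the prefetched path (loads issued ahead of the compute step) and the on-demand path, instead of each branch of `C->is_prefetched()` carrying its own copy of the code.

```cpp
// Illustrative sketch only (hypothetical names, plain C++, no LLVM/Triton
// types): one load routine and one compute routine shared by the prefetched
// and non-prefetched paths, mirroring load_a/load_b/call_mma in this patch.
#include <cstdio>
#include <map>
#include <utility>

int main() {
  // Cache keyed by (m, K); stands in for the `has`/`hbs` maps.
  std::map<std::pair<int, int>, int> cached;

  // Single load routine; `is_prefetch` selects the data source, the way
  // load_a/load_b pick shared_pre_ptr_/shared_next_ptr_ over ptr_a/ptr_b.
  auto load = [&](int m, int K, bool is_prefetch) {
    cached[{m, K}] = is_prefetch ? 100 + m + K  // "staging" source
                                 : m + K;       // "main" source
  };

  // Single compute routine, like call_mma: reads the cache and folds the
  // result into the accumulator captured by reference.
  int acc = 0;
  auto compute = [&](int m, int K) { acc += cached[{m, K}]; };

  // Prefetched path: values are loaded ahead of the compute step.
  load(0, 0, /*is_prefetch=*/true);
  compute(0, 0);

  // Non-prefetched path: load on demand through the same lambda, exactly
  // as the `else` branch now calls load_a/load_b with is_prefetch=false.
  if (cached.find({1, 4}) == cached.end())
    load(1, 4, /*is_prefetch=*/false);
  compute(1, 4);

  std::printf("acc = %d\n", acc);  // (100+0+0) + (1+4) = 105
  return 0;
}
```

Capturing the caches and the accumulator by reference is what lets both call sites share the lambdas without threading extra state through parameters, which is the same design choice the patch makes for `call_mma`, `load_a`, and `load_b`.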