diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 3d33bf3e2..61e513d47 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -11,6 +11,7 @@ namespace llvm{ class Type; class Value; + class PHINode; class BasicBlock; class Attribute; class Instruction; @@ -169,6 +170,7 @@ public: void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_barrier_inst(ir::barrier_inst*); + void visit_prefetch_s_inst(ir::prefetch_s_inst*); void visit_async_wait_inst(ir::async_wait_inst*); // void visit_make_range_dyn(ir::make_range_dyn*); void visit_make_range(ir::make_range*); @@ -209,16 +211,19 @@ private: std::map offset_b_k_; std::map offset_b_n_; + /// layout -> base ptr std::map shared_ptr_; std::map shared_pre_ptr_; std::map shared_next_ptr_; + /// offset for double-buffered layout std::map shared_off_; - + /// Base shmem pointer of ir value std::map shmems_; std::map shoffs_; std::map> idxs_; std::map> vals_; + /// triton bb -> llvm bb std::map bbs_; std::map> ords_; @@ -227,6 +232,11 @@ private: multiplier mul; geper gep; + /// PHI nodes + std::vector> lazy_phi_incs_; + + /// Record prefetch instrs that needs to be moved + std::map> prefetch_latch_to_bb_; }; } diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h index 2b3054f5b..b5a11a46b 100644 --- a/include/triton/codegen/transform/membar.h +++ b/include/triton/codegen/transform/membar.h @@ -5,6 +5,7 @@ #include #include #include +#include "triton/codegen/target.h" namespace triton { @@ -44,14 +45,16 @@ private: std::set &safe_war, bool &inserted, ir::builder &builder); public: - membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc): - liveness_(liveness), layouts_(layouts), alloc_(alloc) {} + membar(analysis::liveness *liveness, analysis::layouts 
*layouts, analysis::allocation *alloc, target* tgt): + liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {} void run(ir::module &mod); private: analysis::liveness *liveness_; analysis::layouts *layouts_; analysis::allocation *alloc_; + + target* tgt_; }; diff --git a/include/triton/codegen/transform/prefetch.h b/include/triton/codegen/transform/prefetch.h new file mode 100644 index 000000000..01fad8875 --- /dev/null +++ b/include/triton/codegen/transform/prefetch.h @@ -0,0 +1,22 @@ +#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H +#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H + +// forward declaration +namespace triton::ir{ +class module; +} + +namespace triton::codegen { +class target; +} + +namespace triton::codegen::transform { +class prefetch { + target* tgt_; +public: + prefetch(target *tgt) : tgt_(tgt) {} + void run(ir::module &module); +}; +} + +#endif \ No newline at end of file diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 1498a8e57..5029aa673 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -151,6 +151,7 @@ public: value *create_copy_from_shared(value *arg); value *create_barrier(const std::string &name = ""); value *create_async_wait(int N); + value *create_prefetch_s(value *arg, int inc); private: context &ctx_; diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index e34a3a1ae..720a2e307 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -143,7 +143,8 @@ enum value_id_t: unsigned { INST_ASYNC_WAIT, INST_MAKE_RANGE_DYN, INST_MAKE_RANGE_STA, - INST_MAKE_RANGE + INST_MAKE_RANGE, + INST_PREFETCH_S, }; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index f5dd397e1..ca06080ff 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -666,6 +666,11 @@ private: dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction
*next); std::string repr_impl() const { return "dot"; } + bool is_prefetched_ = false; +public: + bool is_prefetched() const { return is_prefetched_; } + void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } + public: static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); @@ -821,6 +826,23 @@ private: int N_; }; +class prefetch_s_inst : public instruction { + std::string repr_impl() const { return "prefetch_s"; } + _TRITON_DEFINE_CLONE(prefetch_s_inst) + _TRITON_DEFINE_ACCEPT(prefetch_s_inst) + + /// inc_: 0->first, 1->latch + int inc_ = 0; +public: + prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next) + : instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) { + set_operand(0, arg); + } + int get_inc() const { return inc_; } + static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "", + instruction *next=nullptr); +}; + //// On NVIDIA, implementation is such that //// constant_range = nv_dynamic_program_idx + nv_static_program_idx //// so as to enable re-association on nv_static_program_idx which is constant diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 6062f15ba..f244fbb06 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -70,6 +70,7 @@ class barrier_inst; class async_wait_inst; class make_range_dyn; class make_range; +class prefetch_s_inst; class make_range_sta; class undef_value; @@ -146,6 +147,7 @@ public: virtual void visit_async_wait_inst(async_wait_inst*) = 0; // virtual void visit_make_range_dyn(make_range_dyn*) = 0; virtual void visit_make_range(make_range*) = 0; + virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0; virtual void visit_function(function*) = 0; diff --git 
a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 77d97d941..2100a7770 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -12,6 +12,7 @@ #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" +#include "triton/codegen/transform/prefetch.h" #include "triton/driver/device.h" #include "triton/driver/kernel.h" #include "triton/driver/module.h" @@ -44,11 +45,12 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, codegen::analysis::liveness liveness(&layouts); codegen::analysis::swizzle swizzle(&layouts, target.get()); codegen::analysis::allocation allocation(&liveness); - codegen::transform::membar barriers(&liveness, &layouts, &allocation); + codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get()); codegen::transform::dce dce; codegen::transform::peephole peephole(target.get(), &layouts); // codegen::transform::reassociate reassociate; codegen::transform::coalesce coalesce(&align, &layouts); + codegen::transform::prefetch prefetch_s(target.get()); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps); // run passes dce.run(ir); @@ -90,8 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, swizzle.run(ir); liveness.run(ir); allocation.run(ir); - barriers.run(ir); -// ir::print(ir, std::cout); + prefetch_s.run(ir); + barriers.run(ir); + // ir::print(ir, std::cout); isel.visit(ir, *llvm); mod = driver::module::create(dev, std::move(llvm)); ker = driver::kernel::create(&*mod, name.c_str()); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e7c242f23..671ac1071 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1048,65 +1048,208 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va for(indices_t idx: idxs_.at(C)) 
acc.push_back(vals_[D][idx]); - // update accumulators unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0); unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1); - for(unsigned K = 0; K < NK; K += 4) - for(unsigned m = 0; m < num_m/2; m++) - for(unsigned n = 0; n < num_n/2; n++) { - if(has.find({m, K}) == has.end()){ - Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a]; + + // update accumulators + if (C->is_prefetched()) { + ir::phi_node* phiA = dynamic_cast(A); + ir::phi_node* phiB = dynamic_cast(B); + + /// Cache lds value. If values are prefetched, create phi node + auto register_lds = + [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) { + if (K == 0) { + ir::basic_block* inc_block = phiA->get_incoming_block(inc); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); + } else + vals[{m, K}] = {val0, val1}; + }; + + auto load_a = [&](int m, int K, int inc) ->void { + int offidx = (is_a_row ? K/4 : m) % num_ptr_a; + Value* ptra; + if(K==0){ + if(inc == 0) { + ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); + } + else { + ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); + } + } + else + ptra = ptr_a[offidx]; int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); int step_ak = is_a_row ? 
K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); + // record lds that needs to be moved + if (K == 0 && inc == 1) + prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha); Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); - has[{m, K}] = {ha00, ha01}; + register_lds(has, m, K, inc, ha00, ha01, A); if(vec_a > 4){ Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); if(is_a_row) - has[{m, K+4}] = {ha10, ha11}; + register_lds(has, m, K+4, inc, ha10, ha11, A); else - has[{m+1, K}] = {ha10, ha11}; + register_lds(has, m+1, K, inc, ha10, ha11, A); } - } - if(hbs.find({n, K}) == hbs.end()){ - Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b]; + }; + + auto load_b = [&](int n, int K, int inc){ + int offidx = (is_b_row? n : K/4) % num_ptr_b; + Value* ptrb; + if(K==0){ + if(inc == 0) { + ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); + } + else { + ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); + } + } + else + ptrb = ptr_b[offidx]; + int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; int stepbk = is_b_row ? 
K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); + // record lds that needs to be moved + if (K == 0 && inc == 1) + prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb); Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); - hbs[{n, K}] = {hb00, hb01}; + register_lds(hbs, n, K, inc, hb00, hb01, B); if(vec_b > 4){ Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); if(is_b_row) - hbs[{n+1, K}] = {hb10, hb11}; + register_lds(hbs, n+1, K, inc, hb10, hb11, B); else - hbs[{n, K+4}] = {hb10, hb11}; + register_lds(hbs, n, K+4, inc, hb10, hb11, B); + } + + }; + + // create phis + builder_->SetInsertPoint(curr_bb->getFirstNonPHI()); + for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) { + has[{m, 0}].first = phi(f16x2_ty, 2); + has[{m, 0}].second = phi(f16x2_ty, 2); + if (!is_a_row && vec_a>4) { + has[{m+1, 0}].first = phi(f16x2_ty, 2); + has[{m+1, 0}].second = phi(f16x2_ty, 2); } } - auto ha = has[{m, K}]; - auto hb = hbs[{n, K}]; - // arguments - std::vector idx = { - (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, - (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, - (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, - (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m - }; - std::vector args = {ha.first, ha.second, hb.first, hb.second}; - for(unsigned i = 0; i < 8; i++) - args.push_back(acc[idx[i]]); - // execute mma - Value *nc = call(mma, args); - // unpack - for(unsigned i = 0; i < 8; i++) - acc[idx[i]] = extract_val(nc, {i}); + for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) { + hbs[{n, 0}].first = phi(f16x2_ty, 2); + hbs[{n, 0}].second = phi(f16x2_ty, 2); + if (is_b_row && vec_b>4) { + hbs[{n+1, 0}].first = phi(f16x2_ty, 2); + hbs[{n+1, 0}].second 
= phi(f16x2_ty, 2); + } + } + + builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); + for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) + load_a(m, 0, 0); + for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) + load_b(n, 0, 0); + + // update accumulators + builder_->SetInsertPoint(curr_bb); + for (unsigned K = 0; K < NK; K += 4) { + int NEXTK = (K + 4) % NK; + // prefetch A + for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2) + load_a(m, NEXTK, 1); + // prefetch B + for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1) + load_b(n, NEXTK, 1); + // tensor core ops + for(unsigned m = 0; m < num_m/2; m++) + for(unsigned n = 0; n < num_n/2; n++){ + auto ha = has[{m, K}]; + auto hb = hbs[{n, K}]; + // arguments + std::vector idx = { + (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, + (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, + (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, + (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m + }; + std::vector args = {ha.first, ha.second, hb.first, hb.second}; + for(unsigned i = 0; i < 8; i++) + args.push_back(acc[idx[i]]); + // execute mma + Value *nc = call(mma, args); + // unpack + for(unsigned i = 0; i < 8; i++) + acc[idx[i]] = extract_val(nc, {i}); + } + } + } else { // not prefetched + for(unsigned K = 0; K < NK; K += 4) + for(unsigned m = 0; m < num_m/2; m++) + for(unsigned n = 0; n < num_n/2; n++) { + if(has.find({m, K}) == has.end()){ + Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a]; + int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); + int step_ak = is_a_row ? 
K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; + Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); + Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); + Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); + Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); + has[{m, K}] = {ha00, ha01}; + if(vec_a > 4){ + Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); + Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); + if(is_a_row) + has[{m, K+4}] = {ha10, ha11}; + else + has[{m+1, K}] = {ha10, ha11}; + } + } + if(hbs.find({n, K}) == hbs.end()){ + Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b]; + int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; + int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); + Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); + Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); + Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); + Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); + hbs[{n, K}] = {hb00, hb01}; + if(vec_b > 4){ + Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); + Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); + if(is_b_row) + hbs[{n+1, K}] = {hb10, hb11}; + else + hbs[{n, K+4}] = {hb10, hb11}; + } + } + auto ha = has[{m, K}]; + auto hb = hbs[{n, K}]; + // arguments + std::vector idx = { + (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, + (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, + (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, + (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m + }; + std::vector args = {ha.first, ha.second, hb.first, hb.second}; + for(unsigned i = 0; i < 8; i++) + args.push_back(acc[idx[i]]); + // execute mma + Value *nc = call(mma, args); + // unpack + for(unsigned i = 0; i < 8; i++) + acc[idx[i]] = extract_val(nc, {i}); + } } // write back accumulators @@ -1827,6 +1970,40 @@ void 
generator::visit_barrier_inst(ir::barrier_inst*) { add_barrier(); } +void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) { + ir::value *v = i->get_operand(0); + int inc = i->get_inc(); + if (inc == 0) { + // If dot has not been visited, do nothing. + } else { + // If dot has been visited, insert prefetched lds + assert(inc == 1); + assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() && + "dot hasn't be visited"); + // sink lds & extract element + // move lds & all uses to current location + std::stack work_stack; + for (Value *value : prefetch_latch_to_bb_[v]) + work_stack.push(value); + std::vector dead_instrs; + while (!work_stack.empty()) { + Value *m = work_stack.top(); + work_stack.pop(); + + for (auto u : m->users()) + work_stack.push(u); + + assert(isa(m)); + auto m_instr = static_cast(m); + + m_instr->removeFromParent(); + m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end())); + assert(m_instr->getParent() == &*builder_->GetInsertBlock()); + builder_->SetInsertPoint(m_instr->getParent()); + } + } +} + void generator::visit_async_wait_inst(ir::async_wait_inst* i) { std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";"; InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); @@ -2144,6 +2321,7 @@ void generator::visit_basic_block(ir::basic_block * block) { for(ir::instruction *i: block->get_inst_list()){ visit_value(i); } + // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); } @@ -2247,6 +2425,8 @@ void generator::finalize_function(ir::function *fn) { for(ir::instruction *inst: block->get_inst_list()) if(auto *phi = dynamic_cast(inst)) finalize_phi_node(phi); + for(auto& x: lazy_phi_incs_) + std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]); } void generator::finalize_phi_node(ir::phi_node *x) { diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 517fd96d9..95bb044b8 100644 ---
a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -96,7 +96,15 @@ void membar::transfer(ir::basic_block *block, } } // RAW, WAR - if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){ + bool is_i_double_buffered = i->get_type()->is_block_ty() && + layouts_->get(i)->to_shared() && + layouts_->get(i)->to_shared()->get_double_buffer(); + // WAR barrier is not required when data is double-buffered + // TODO: how about other patterns, like WWAR? + if(!intersect_with(read, sync_write).empty() || + (!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) || + // force WAR barrier on A100 + (!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){ builder.set_insert_point(i); barrier = (ir::barrier_inst*)builder.create_barrier(); inserted = true; @@ -133,6 +141,7 @@ void membar::run(ir::module &mod) { } for(ir::function *fn: mod.get_function_list()){ + // TODO: (dyan) we need DominatorTree here. std::vector rpo = ir::cfg::reverse_post_order(fn); std::map async_writes; std::map sync_writes; @@ -166,4 +175,4 @@ void membar::run(ir::module &mod) { } } -} +} \ No newline at end of file diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc new file mode 100644 index 000000000..876ee3e43 --- /dev/null +++ b/lib/codegen/transform/prefetch.cc @@ -0,0 +1,105 @@ +#include "triton/codegen/transform/prefetch.h" +#include "triton/codegen/target.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/utils.h" +#include "triton/ir/print.h" +#include +#include +#include + +namespace triton::codegen::transform { + +/// find defs till phis +static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector &ret) { + ir::instruction *i = dynamic_cast(v); + if (!i || i->get_parent() != bb) + return; + if (i->get_id() == ir::INST_PHI) + return; + ret.push_back(i); + for 
(ir::value *op : i->ops()) + recursive_defs(op, bb, ret); +} + +void prefetch::run(ir::module &mod) { + // 1. collect dot that can be prefetched + std::vector to_prefetch; + ir::for_each_instruction(mod, [&](ir::instruction *i) { + if (auto *dot = dynamic_cast(i)) { + // Now only do prefetching when dot is fp16 & volta/turing + if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID || + tgt_->as_nvidia()->sm() >= 80) + return; + auto *a = dynamic_cast(dot->get_operand(0)); + auto *b = dynamic_cast(dot->get_operand(1)); + if (a && a->get_incoming_block(1) == a->get_parent() && + b && b->get_incoming_block(1) == b->get_parent()) + to_prefetch.push_back(dot); + } + }); + + assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots"); + ir::builder &builder = mod.get_builder(); + // 2. do the prefetching + for (ir::dot_inst* dot : to_prefetch) { + auto *a = dynamic_cast(dot->get_operand(0)); + auto *b = dynamic_cast(dot->get_operand(1)); + assert(a->get_incoming_block(0) == b->get_incoming_block(0)); + ir::basic_block *loop_header = a->get_incoming_block(0); + ir::basic_block *loop_body = a->get_parent(); + + // mark as prefetched + dot->set_prefetched(true); + + // 1. in the loop header (first iteration) + builder.set_insert_point(loop_header->get_inst_list().back()); + builder.create_barrier(); + assert(a && b); + builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0); + builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0); + + // 2.
at the end of the loop body (next iteration) + builder.set_insert_point(loop_body->get_inst_list().back()); + builder.create_barrier(); + builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1); + builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1); + } + + // move loads to the beginning of the loop + if (tgt_->as_nvidia()->sm() < 80) { + for (ir::function *fn : mod.get_function_list()) + for (ir::basic_block *bb : fn->blocks()) { + // only apply to loop body + if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb) + continue; + // record loads (& dependency) to move + std::vector loads; + // record original inst order + std::map idx_map; + size_t idx = 0; + for (ir::instruction *inst : bb->get_inst_list()) { + if (auto *i = dynamic_cast(inst)) + recursive_defs(i, bb, loads); + idx_map[inst] = idx; + idx++; + } + + // remove duplicates & keep the original input order + std::sort(loads.begin(), loads.end()); + loads.erase(std::unique(loads.begin(), loads.end()), loads.end()); + std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) { + return idx_map[a] < idx_map[b]; + }); + + builder.set_insert_point(bb->get_first_non_phi()); + for (ir::instruction *i : loads) { + bb->erase(i); + builder.insert(i); + } + } + } +} +} // namespace triton::codegen::transform \ No newline at end of file diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 26fc60692..ed984ad43 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -209,7 +209,8 @@ static std::map vptx = { {10020, 65}, {11000, 70}, {11010, 71}, - {11020, 72} + {11020, 72}, + {11030, 73}, }; std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 051e26636..b6baa9a2e 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -381,6 +381,9 @@ value *builder::create_async_wait(int N) { return insert(async_wait_inst::create(ctx_, N)); } +value 
*builder::create_prefetch_s(value *arg, int inc) { + return insert(prefetch_s_inst::create(ctx_, arg, inc)); +} } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 9b013f1d5..1162b1d15 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -832,6 +832,10 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string return new async_wait_inst(ctx, N, name, next); } +// prefetch_s +prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) { + return new prefetch_s_inst(ctx, arg, inc, name, next); +} //// nv_dynamic_program_idx //make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)