From 0cc3b1129beb9f99235a7bc07d27f2565d53ffe7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 21 Apr 2022 23:56:01 -0700 Subject: [PATCH] [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) --- include/triton/codegen/selection/generator.h | 4 +++ lib/codegen/selection/generator.cc | 26 +++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index e3191efb1..a4f1d33af 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -4,6 +4,7 @@ #define _TRITON_SELECTION_GENERATOR_H_ #include "triton/ir/visitor.h" +#include "triton/ir/instructions.h" #include "triton/codegen/analysis/layout.h" #include @@ -261,6 +262,9 @@ private: /// Record prefetch instrs that needs to be moved std::map> prefetch_latch_to_bb_; + + // Eviction policies + std::map policies_; }; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 0e6ae4539..c60350060 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -91,6 +91,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() +#define i64_ty builder_->getInt64Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) // constants @@ -778,6 +779,7 @@ void generator::visit_load_inst(ir::load_inst* x){ int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); + bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL; // ----- // create inline asm string // ----- @@ -789,8 +791,9 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; - if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; + if (has_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size @@ -801,7 +804,9 @@ void generator::visit_load_inst(ir::load_inst* x){ } asm_oss << "}"; asm_oss << ", [ $" << n_words + 1; // load - asm_oss << " + " << in_off << "];"; // constant offset + asm_oss << " + " << in_off << "]"; // constant offset + if (has_evict_policy) asm_oss << ", $" << n_words + 2; + asm_oss << ";"; bool has_other = other && (other != UndefValue::get(other->getType())); std::vector others; // handle `other` values for indices where the mask @@ -822,7 +827,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if(ConstantInt* cst = dyn_cast(v)) asm_oss << "0x" << std::hex << cst->getSExtValue(); else{ - asm_oss << "$" << n_words + 2 + ii; + asm_oss << "$" << n_words + has_evict_policy + 2 + ii; others.push_back(v); } asm_oss.flags(flags); @@ -837,6 +842,8 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); + if (has_evict_policy) + arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- // create inline ASM constraints @@ -851,6 +858,8 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); } + if (has_evict_policy) + asm_cstrt += ",l"; // --- // finally call inline ASM // --- @@ -858,6 +867,10 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); + if (has_evict_policy) + args.push_back(policies_.at(x->get_eviction_policy())); + + Value *_ret = call(inlineAsm, args); // --- // extract and store return values @@ -2880,6 +2893,13 @@ void generator::visit_function(ir::function* fn) { bbs_[block] = dst_block; } builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); + // create policies + for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ + std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last"; + std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;"; + InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false); + policies_[evict] = call(iasm); + } // initialize layouts for(auto x: layouts_->get_all()){ visit_layout(x.second);