[BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501)

This commit is contained in:
Philippe Tillet
2022-04-21 23:56:01 -07:00
committed by GitHub
parent 7d6c504e8d
commit 0cc3b1129b
2 changed files with 27 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h"
#include <functional>
@@ -261,6 +262,9 @@ private:
/// Record prefetch instructions that need to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
// Eviction policies
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
};
}

View File

@@ -91,6 +91,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// Shorthand for LLVM integer types. Each macro expands to a call on
// `builder_`, so it is only usable where a `builder_` pointer is in scope.
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty()
#define i64_ty builder_->getInt64Ty()
// Vector type with `num_el` elements of `type`.
// NOTE(review): the trailing `false` is presumably LLVM's "scalable" flag
// (i.e. a fixed-width vector) -- confirm against the LLVM version in use.
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
// Forwards all of its arguments unchanged to PointerType::get.
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
@@ -778,6 +779,7 @@ void generator::visit_load_inst(ir::load_inst* x){
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL;
// -----
// create inline asm string
// -----
@@ -789,8 +791,9 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_oss << ".global";
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (has_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
@@ -801,7 +804,9 @@ void generator::visit_load_inst(ir::load_inst* x){
}
asm_oss << "}";
asm_oss << ", [ $" << n_words + 1; // load
asm_oss << " + " << in_off << "];"; // constant offset
asm_oss << " + " << in_off << "]"; // constant offset
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
bool has_other = other && (other != UndefValue::get(other->getType()));
std::vector<Value *> others;
// handle `other` values for indices where the mask
@@ -822,7 +827,7 @@ void generator::visit_load_inst(ir::load_inst* x){
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
asm_oss << "0x" << std::hex << cst->getSExtValue();
else{
asm_oss << "$" << n_words + 2 + ii;
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
others.push_back(v);
}
asm_oss.flags(flags);
@@ -837,6 +842,8 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
if (has_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
// ---
// create inline ASM constraints
@@ -851,6 +858,8 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
// ---
@@ -858,6 +867,10 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Value*> args = {pred, ptr};
for(Value *v: others)
args.push_back(v);
if (has_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
Value *_ret = call(inlineAsm, args);
// ---
// extract and store return values
@@ -2880,6 +2893,13 @@ void generator::visit_function(ir::function* fn) {
bbs_[block] = dst_block;
}
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
// create policies
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
// initialize layouts
for(auto x: layouts_->get_all()){
visit_layout(x.second);