[BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501)
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#define _TRITON_SELECTION_GENERATOR_H_
|
||||
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include <functional>
|
||||
|
||||
@@ -261,6 +262,9 @@ private:
|
||||
|
||||
/// Record prefetch instrs that needs to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
|
||||
// Eviction policies
|
||||
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -91,6 +91,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i16_ty builder_->getInt16Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define i64_ty builder_->getInt64Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||
// constants
|
||||
@@ -778,6 +779,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
int tot_width = nbits*vec;
|
||||
int width = std::min(tot_width, max_word_width);
|
||||
int n_words = std::max(1, tot_width / width);
|
||||
bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL;
|
||||
// -----
|
||||
// create inline asm string
|
||||
// -----
|
||||
@@ -789,8 +791,9 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
asm_oss << ".global";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||
if (has_evict_policy) asm_oss << ".L2::cache_hint";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
@@ -801,7 +804,9 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
}
|
||||
asm_oss << "}";
|
||||
asm_oss << ", [ $" << n_words + 1; // load
|
||||
asm_oss << " + " << in_off << "];"; // constant offset
|
||||
asm_oss << " + " << in_off << "]"; // constant offset
|
||||
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
|
||||
asm_oss << ";";
|
||||
bool has_other = other && (other != UndefValue::get(other->getType()));
|
||||
std::vector<Value *> others;
|
||||
// handle `other` values for indices where the mask
|
||||
@@ -822,7 +827,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
|
||||
asm_oss << "0x" << std::hex << cst->getSExtValue();
|
||||
else{
|
||||
asm_oss << "$" << n_words + 2 + ii;
|
||||
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
|
||||
others.push_back(v);
|
||||
}
|
||||
asm_oss.flags(flags);
|
||||
@@ -837,6 +842,8 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
||||
for(Value *v: others)
|
||||
arg_tys.push_back(v->getType());
|
||||
if (has_evict_policy)
|
||||
arg_tys.push_back(i64_ty);
|
||||
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
|
||||
// ---
|
||||
// create inline ASM constraints
|
||||
@@ -851,6 +858,8 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
asm_cstrt += ",";
|
||||
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
}
|
||||
if (has_evict_policy)
|
||||
asm_cstrt += ",l";
|
||||
// ---
|
||||
// finally call inline ASM
|
||||
// ---
|
||||
@@ -858,6 +867,10 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
std::vector<Value*> args = {pred, ptr};
|
||||
for(Value *v: others)
|
||||
args.push_back(v);
|
||||
if (has_evict_policy)
|
||||
args.push_back(policies_.at(x->get_eviction_policy()));
|
||||
|
||||
|
||||
Value *_ret = call(inlineAsm, args);
|
||||
// ---
|
||||
// extract and store return values
|
||||
@@ -2880,6 +2893,13 @@ void generator::visit_function(ir::function* fn) {
|
||||
bbs_[block] = dst_block;
|
||||
}
|
||||
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
|
||||
// create policies
|
||||
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
|
||||
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
|
||||
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
|
||||
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
|
||||
policies_[evict] = call(iasm);
|
||||
}
|
||||
// initialize layouts
|
||||
for(auto x: layouts_->get_all()){
|
||||
visit_layout(x.second);
|
||||
|
Reference in New Issue
Block a user