[BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501)
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
#define _TRITON_SELECTION_GENERATOR_H_
|
#define _TRITON_SELECTION_GENERATOR_H_
|
||||||
|
|
||||||
#include "triton/ir/visitor.h"
|
#include "triton/ir/visitor.h"
|
||||||
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/codegen/analysis/layout.h"
|
#include "triton/codegen/analysis/layout.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
@@ -261,6 +262,9 @@ private:
|
|||||||
|
|
||||||
/// Record prefetch instrs that needs to be moved
|
/// Record prefetch instrs that needs to be moved
|
||||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||||
|
|
||||||
|
// Eviction policies
|
||||||
|
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -91,6 +91,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
|||||||
#define i8_ty builder_->getInt8Ty()
|
#define i8_ty builder_->getInt8Ty()
|
||||||
#define i16_ty builder_->getInt16Ty()
|
#define i16_ty builder_->getInt16Ty()
|
||||||
#define i32_ty builder_->getInt32Ty()
|
#define i32_ty builder_->getInt32Ty()
|
||||||
|
#define i64_ty builder_->getInt64Ty()
|
||||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||||
// constants
|
// constants
|
||||||
@@ -778,6 +779,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
int tot_width = nbits*vec;
|
int tot_width = nbits*vec;
|
||||||
int width = std::min(tot_width, max_word_width);
|
int width = std::min(tot_width, max_word_width);
|
||||||
int n_words = std::max(1, tot_width / width);
|
int n_words = std::max(1, tot_width / width);
|
||||||
|
bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL;
|
||||||
// -----
|
// -----
|
||||||
// create inline asm string
|
// create inline asm string
|
||||||
// -----
|
// -----
|
||||||
@@ -789,8 +791,9 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
asm_oss << ".global";
|
asm_oss << ".global";
|
||||||
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
||||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
|
||||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
||||||
|
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||||
|
if (has_evict_policy) asm_oss << ".L2::cache_hint";
|
||||||
if(n_words > 1)
|
if(n_words > 1)
|
||||||
asm_oss << ".v" << n_words; // vector width
|
asm_oss << ".v" << n_words; // vector width
|
||||||
asm_oss << ".b" << width; // word size
|
asm_oss << ".b" << width; // word size
|
||||||
@@ -801,7 +804,9 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
}
|
}
|
||||||
asm_oss << "}";
|
asm_oss << "}";
|
||||||
asm_oss << ", [ $" << n_words + 1; // load
|
asm_oss << ", [ $" << n_words + 1; // load
|
||||||
asm_oss << " + " << in_off << "];"; // constant offset
|
asm_oss << " + " << in_off << "]"; // constant offset
|
||||||
|
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
|
||||||
|
asm_oss << ";";
|
||||||
bool has_other = other && (other != UndefValue::get(other->getType()));
|
bool has_other = other && (other != UndefValue::get(other->getType()));
|
||||||
std::vector<Value *> others;
|
std::vector<Value *> others;
|
||||||
// handle `other` values for indices where the mask
|
// handle `other` values for indices where the mask
|
||||||
@@ -822,7 +827,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
|
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
|
||||||
asm_oss << "0x" << std::hex << cst->getSExtValue();
|
asm_oss << "0x" << std::hex << cst->getSExtValue();
|
||||||
else{
|
else{
|
||||||
asm_oss << "$" << n_words + 2 + ii;
|
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
|
||||||
others.push_back(v);
|
others.push_back(v);
|
||||||
}
|
}
|
||||||
asm_oss.flags(flags);
|
asm_oss.flags(flags);
|
||||||
@@ -837,6 +842,8 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
||||||
for(Value *v: others)
|
for(Value *v: others)
|
||||||
arg_tys.push_back(v->getType());
|
arg_tys.push_back(v->getType());
|
||||||
|
if (has_evict_policy)
|
||||||
|
arg_tys.push_back(i64_ty);
|
||||||
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
|
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
|
||||||
// ---
|
// ---
|
||||||
// create inline ASM constraints
|
// create inline ASM constraints
|
||||||
@@ -851,6 +858,8 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
asm_cstrt += ",";
|
asm_cstrt += ",";
|
||||||
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||||
}
|
}
|
||||||
|
if (has_evict_policy)
|
||||||
|
asm_cstrt += ",l";
|
||||||
// ---
|
// ---
|
||||||
// finally call inline ASM
|
// finally call inline ASM
|
||||||
// ---
|
// ---
|
||||||
@@ -858,6 +867,10 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
std::vector<Value*> args = {pred, ptr};
|
std::vector<Value*> args = {pred, ptr};
|
||||||
for(Value *v: others)
|
for(Value *v: others)
|
||||||
args.push_back(v);
|
args.push_back(v);
|
||||||
|
if (has_evict_policy)
|
||||||
|
args.push_back(policies_.at(x->get_eviction_policy()));
|
||||||
|
|
||||||
|
|
||||||
Value *_ret = call(inlineAsm, args);
|
Value *_ret = call(inlineAsm, args);
|
||||||
// ---
|
// ---
|
||||||
// extract and store return values
|
// extract and store return values
|
||||||
@@ -2880,6 +2893,13 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
bbs_[block] = dst_block;
|
bbs_[block] = dst_block;
|
||||||
}
|
}
|
||||||
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
|
builder_->SetInsertPoint(bbs_[fn->blocks()[0]]);
|
||||||
|
// create policies
|
||||||
|
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
|
||||||
|
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
|
||||||
|
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
|
||||||
|
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
|
||||||
|
policies_[evict] = call(iasm);
|
||||||
|
}
|
||||||
// initialize layouts
|
// initialize layouts
|
||||||
for(auto x: layouts_->get_all()){
|
for(auto x: layouts_->get_all()){
|
||||||
visit_layout(x.second);
|
visit_layout(x.second);
|
||||||
|
Reference in New Issue
Block a user