[CODEGEN] Add cache modifier to tl.load (#351)

* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
This commit is contained in:
daadaada
2021-10-18 13:14:04 +08:00
committed by GitHub
parent 90ded16c32
commit 858dec8372
16 changed files with 119 additions and 63 deletions

View File

@@ -33,7 +33,7 @@ namespace codegen{
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, bool force_nc_cache, int &shared_static);
int num_stages, int &shared_static);
}

View File

@@ -122,8 +122,7 @@ public:
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps,
bool force_nc_cache = false);
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
@@ -213,7 +212,6 @@ private:
std::set<ir::value*> seen_;
unsigned num_warps_;
bool force_nc_cache_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;

View File

@@ -130,9 +130,9 @@ public:
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg);
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
@@ -154,7 +154,7 @@ public:
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);

View File

@@ -67,7 +67,7 @@ struct dispatch{
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);

View File

@@ -394,22 +394,38 @@ public:
// load
class load_inst: public io_inst {
public:
enum CACHE_MODIFIER : uint32_t {
NONE=0,
CA,
CG,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops,
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
const std::string &name = "", instruction *next = nullptr);
std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca";
if (cache_ == CG) return ".cg";
return "";
}
CACHE_MODIFIER cache_;
private:
static type *get_pointee_type(type *ty);
};
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load"; }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -419,8 +435,8 @@ public:
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value,
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next);
public:
@@ -429,6 +445,7 @@ public:
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
@@ -438,8 +455,8 @@ public:
// masked load async
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async_async"; }
masked_load_async_inst(value *ptr, value *mask, value *false_value,
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next);
public:
@@ -448,6 +465,7 @@ public:
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)