[CODEGEN] Add cache modifier to tl.load (#351)

* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
This commit is contained in:
daadaada
2021-10-18 13:14:04 +08:00
committed by GitHub
parent 90ded16c32
commit 858dec8372
16 changed files with 119 additions and 63 deletions

View File

@@ -33,7 +33,7 @@ namespace codegen{
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target, codegen::target* target,
int sm, int num_warps, int sm, int num_warps,
int num_stages, bool force_nc_cache, int &shared_static); int num_stages, int &shared_static);
} }

View File

@@ -122,8 +122,7 @@ public:
analysis::allocation *alloc, analysis::allocation *alloc,
analysis::swizzle *swizzle, analysis::swizzle *swizzle,
target *tgt, target *tgt,
unsigned num_warps, unsigned num_warps);
bool force_nc_cache = false);
void visit_value(ir::value* v); void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*); void visit_phi_node(ir::phi_node*);
@@ -213,7 +212,6 @@ private:
std::set<ir::value*> seen_; std::set<ir::value*> seen_;
unsigned num_warps_; unsigned num_warps_;
bool force_nc_cache_;
std::map<analysis::data_layout*, Value*> offset_a_m_; std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_; std::map<analysis::data_layout*, Value*> offset_a_k_;

View File

@@ -130,9 +130,9 @@ public:
value *create_xor(value *lhs, value *rhs); value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs);
// Input/Output // Input/Output
value *create_load(value *arg); value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
value *create_store(value *ptr, value *val); value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value); value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_masked_store(value *ptr, value *val, value *mask); value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction // Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes); value *create_splat(value *arg, const type::block_shapes_t &shapes);
@@ -154,7 +154,7 @@ public:
value *create_select(value *pred, value *if_value, value *else_value); value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics // Intrinsics
value *create_copy_to_shared(value *arg); value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value); value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_copy_from_shared(value *arg); value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = ""); value *create_barrier(const std::string &name = "");
value *create_async_wait(int N); value *create_async_wait(int N);

View File

@@ -67,7 +67,7 @@ struct dispatch{
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators // memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder); static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);

View File

@@ -394,22 +394,38 @@ public:
// load // load
class load_inst: public io_inst { class load_inst: public io_inst {
public:
// Cache modifier attached to a load instruction. The code generator appends
// the matching suffix to the `ld.global` PTX opcode it emits for the load.
enum CACHE_MODIFIER : uint32_t {
// Default: no cache suffix is emitted.
NONE=0,
// Emits the `.ca` suffix.
CA,
// Emits the `.cg` suffix.
CG,
};
// Returns the cache modifier stored in this load instruction.
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
protected: protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
// Textual suffix for this load's cache modifier: ".ca" for CA, ".cg" for CG,
// and an empty string for anything else (including NONE).
std::string get_cache_modifier_repr() const {
  switch (cache_) {
    case CA:
      return ".ca";
    case CG:
      return ".cg";
    default:
      return "";
  }
}
CACHE_MODIFIER cache_;
private: private:
static type *get_pointee_type(type *ty); static type *get_pointee_type(type *ty);
}; };
// unmasked load // unmasked load
class unmasked_load_inst: public load_inst { class unmasked_load_inst: public load_inst {
private: private:
std::string repr_impl() const { return "unmasked_load"; } std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next); unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
public: public:
static unmasked_load_inst* create(value *ptr, static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst) _TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -419,8 +435,8 @@ public:
// masked load // masked load
class masked_load_inst: public load_inst { class masked_load_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load"; } std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next); const std::string &name, instruction *next);
public: public:
@@ -429,6 +445,7 @@ public:
value *get_false_value_operand() { return get_operand(2); } value *get_false_value_operand() { return get_operand(2); }
// factory method // factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value, static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst) _TRITON_DEFINE_CLONE(masked_load_inst)
@@ -438,8 +455,8 @@ public:
// masked load async // masked load async
class masked_load_async_inst: public load_inst { class masked_load_async_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load_async_async"; } std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next); const std::string &name, instruction *next);
public: public:
@@ -448,6 +465,7 @@ public:
value *get_false_value_operand() { return get_operand(2); } value *get_false_value_operand() { return get_operand(2); }
// factory method // factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst) _TRITON_DEFINE_CLONE(masked_load_async_inst)

View File

@@ -25,7 +25,7 @@ namespace codegen {
// TODO: // TODO:
// There should be a proper pass manager there! // There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) { int cc, int num_warps, int num_stages, int& shared_static) {
// generate llvm code // generate llvm code
std::string name = ir.get_function_list()[0]->get_name(); std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
@@ -46,7 +46,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target); codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes // run passes
dce.run(ir); dce.run(ir);
peephole.run(ir); peephole.run(ir);

View File

@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
analysis::allocation *alloc, analysis::allocation *alloc,
analysis::swizzle *swizzle, analysis::swizzle *swizzle,
target *tgt, target *tgt,
unsigned num_warps, bool force_nc_cache) unsigned num_warps)
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle), : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) { tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
} }
@@ -629,10 +629,9 @@ void generator::visit_load_inst(ir::load_inst* x){
// ----- // -----
std::ostringstream asm_oss; std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate asm_oss << "@$" << n_words; // predicate
// if(force_nc_cache_) asm_oss << " ld.global";
asm_oss << " ld.global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
// else if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
// asm_oss << " ld.global.cg";
if(n_words > 1) if(n_words > 1)
asm_oss << ".v" << n_words; // vector width asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size asm_oss << ".b" << width; // word size

View File

@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]); int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){ if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val); ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
copy_to_shared->replace_all_uses_with(new_load); copy_to_shared->replace_all_uses_with(new_load);
return true; return true;
} }
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
builder.set_insert_point(select); builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(), if_value->get_mask_operand(),
select->get_else_value_op()); select->get_else_value_op(),
if_value->get_cache_modifier());
select->replace_all_uses_with(new_load); select->replace_all_uses_with(new_load);
return true; return true;
} }

View File

@@ -111,6 +111,8 @@ struct pipeline_info_t {
}; };
void pipeline::run(ir::module &mod) { void pipeline::run(ir::module &mod) {
if (num_stages_ <= 1)
return;
// *Very* conservative heuristics for pre-fetching. // *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if: // A load instruction can be pipelined if:
// - the pointer is a phi node that references a value // - the pointer is a phi node that references a value
@@ -176,7 +178,7 @@ void pipeline::run(ir::module &mod) {
false_value = remat_false_value; false_value = remat_false_value;
} else } else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value); first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());
for (int stage = 1; stage < num_stages-1; ++stage) { for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration // mask is the loop condition of the previous iteration
@@ -191,7 +193,7 @@ void pipeline::run(ir::module &mod) {
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value); first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
} }
// create new phis for induction variables // create new phis for induction variables
@@ -220,7 +222,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask); next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
// phi node // phi node
@@ -255,7 +257,7 @@ void pipeline::run(ir::module &mod) {
} }
else else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value); ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
// pre-fetch next iteration // pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back()); builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block); ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -266,7 +268,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask); next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
// phi node // phi node
builder.set_insert_point(block->get_first_non_phi()); builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2); ir::phi_node* new_load = builder.create_phi(ty, 2);

View File

@@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
// load/store instructions // load/store instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr){ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){
return insert(unmasked_load_inst::create(ptr)); return insert(unmasked_load_inst::create(ptr, cache));
} }
value *builder::create_store(value *ptr, value *val){ value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val)); return insert(unmasked_store_inst::create(ptr, val));
} }
value *builder::create_masked_load(value *ptr, value *mask, value *false_value){ value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){
return insert(masked_load_inst::create(ptr, mask, false_value)); return insert(masked_load_inst::create(ptr, mask, false_value, cache));
} }
value *builder::create_masked_store(value *ptr, value *val, value *mask){ value *builder::create_masked_store(value *ptr, value *val, value *mask){
@@ -379,8 +379,8 @@ value *builder::create_copy_from_shared(value *arg) {
return insert(copy_from_shared_inst::create(arg)); return insert(copy_from_shared_inst::create(arg));
} }
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) { value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
return insert(masked_load_async_inst::create(ptr, mask, false_value)); return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
} }
value *builder::create_barrier(const std::string &name) { value *builder::create_barrier(const std::string &name) {

View File

@@ -497,7 +497,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
// Memory Operators // Memory Operators
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) { ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){ if(ptr->get_type()->is_block_ty()){
@@ -517,8 +517,17 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder); ptr = dispatch::cast(ptr, ptr_ty, builder);
} }
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
if (!cache_modifier.empty()) {
if (cache_modifier == ".ca")
cache = load_inst::CA;
else if (cache_modifier == ".cg")
cache = load_inst::CG;
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
if (!mask && !other) if (!mask && !other)
return builder->create_load(ptr); return builder->create_load(ptr, cache);
if (!mask) if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`"); throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes(); auto shape = ptr->get_type()->get_block_shapes();
@@ -527,7 +536,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
if(ptr->get_type()->is_block_ty()) if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
} }
return builder->create_masked_load(ptr, mask, other); return builder->create_masked_load(ptr, mask, other, cache);
} }
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {

View File

@@ -433,8 +433,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
{ } { }
// load_inst // load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next) : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache)
{ } { }
// load // load
@@ -447,41 +447,44 @@ type *load_inst::get_pointee_type(type *ty) {
} }
// unmasked_load // unmasked_load
unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next) unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) { : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
} }
unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) { unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, name, next); return new unmasked_load_inst(ptr, cache, name, next);
} }
// masked load // masked load
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD, 3, name, next) { : load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, mask); set_operand(1, mask);
set_operand(2, false_value); set_operand(2, false_value);
} }
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new masked_load_inst(ptr, mask, false_value, name, next); return new masked_load_inst(ptr, mask, false_value, cache, name, next);
} }
// masked load async // masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) { : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, mask); set_operand(1, mask);
set_operand(2, false_value); set_operand(2, false_value);
} }
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, name, next); return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
} }
// store // store

View File

@@ -203,7 +203,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
// CUDA // CUDA
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir, std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages, uint64_t device, int num_warps, int num_stages,
bool force_nc_cache, asm_map_t &asm_map){ asm_map_t &asm_map){
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
// device properties // device properties
CUdevice dev = (CUdevice)device; CUdevice dev = (CUdevice)device;
@@ -215,7 +215,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
// Triton-IR -> NVPTX LLVM-IR // Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc); triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes; int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes); auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
std::string tmp; std::string tmp;
llvm::raw_string_ostream llir(tmp); llvm::raw_string_ostream llir(tmp);
llir << *llvm; llir << *llvm;
@@ -236,12 +236,12 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
// HIP // HIP
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir, std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages, uint64_t device, int num_warps, int num_stages,
bool force_nc_cache, asm_map_t &asm_map){ asm_map_t &asm_map){
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
// Triton-IR -> NVPTX LLVM-IR // Triton-IR -> NVPTX LLVM-IR
triton::codegen::amd_cl_target target; triton::codegen::amd_cl_target target;
int n_shared_bytes; int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes); auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
std::string tmp; std::string tmp;
llvm::raw_string_ostream llir(tmp); llvm::raw_string_ostream llir(tmp);
llir << *llvm; llir << *llvm;
@@ -255,7 +255,7 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
void init_triton_codegen(py::module &&m) { void init_triton_codegen(py::module &&m) {
m.def( m.def(
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) { "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
std::string name = ir.get_function_list()[0]->get_name(); std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate // record asm as we generate
asm_map_t asm_map; asm_map_t asm_map;
@@ -264,9 +264,9 @@ void init_triton_codegen(py::module &&m) {
asm_map["ttir"] = py::cast(ttir.str()); asm_map["ttir"] = py::cast(ttir.str());
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
if(backend == CUDA) if(backend == CUDA)
return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map); return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
if(backend == ROCM) if(backend == ROCM)
return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map); return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
}, py::return_value_policy::take_ownership); }, py::return_value_policy::take_ownership);
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
if(backend == CUDA) if(backend == CUDA)

View File

@@ -599,6 +599,30 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
reference_out =torch.matmul(in1, in2) reference_out =torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out) triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
    # Compile a trivial copy kernel with the requested cache modifier and
    # inspect the generated PTX for the matching `ld.global` suffix.
    src = torch.empty(128, device='cuda')
    dst = torch.empty(128, device='cuda')

    @triton.jit
    def _kernel(dst, src, **meta):
        offsets = tl.arange(0, 128)
        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
        tl.store(dst+offsets, x)

    pgm = _kernel[(1,)](dst, src, CACHE=cache)
    ptx = pgm.asm['ptx']

    # Evaluate both substring checks once; each case asserts on the pair.
    has_ca = 'ld.global.ca' in ptx
    has_cg = 'ld.global.cg' in ptx
    if cache == '':
        # No modifier requested: neither suffix may appear.
        assert not has_ca
        assert not has_cg
    elif cache == '.cg':
        assert has_cg
        assert not has_ca
    elif cache == '.ca':
        assert has_ca
        assert not has_cg
# --------------- # ---------------
# test store # test store
# --------------- # ---------------

View File

@@ -537,7 +537,7 @@ class Kernel:
def __init__(self, fn): def __init__(self, fn):
self.fn = fn self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta): def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
# create IR module # create IR module
context = _triton.ir.context() context = _triton.ir.context()
# get just-in-time proto-type of kernel # get just-in-time proto-type of kernel
@@ -560,13 +560,13 @@ class Kernel:
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
else: else:
backend = _triton.runtime.backend.ROCM backend = _triton.runtime.backend.ROCM
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache) name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
max_shared_memory = _triton.runtime.max_shared_memory(backend, device) max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
if shared_mem > max_shared_memory: if shared_mem > max_shared_memory:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory") raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps) return Binary(backend, name, asm, shared_mem, num_warps)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta): def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
# device inference # device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0: if len(tensor_idxs) == 0:
@@ -643,7 +643,7 @@ class Kernel:
if binary is None: if binary is None:
binary = self._compile( binary = self._compile(
*wargs, device=device_idx, attributes=attributes, *wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache, num_warps=num_warps, num_stages=num_stages,
constants=constants, **meta constants=constants, **meta
) )
if bin_cache_path: if bin_cache_path:

View File

@@ -387,7 +387,7 @@ def dot(input, other, _builder=None):
@builtin @builtin
def load(pointer, mask=None, other=None, _builder=None): def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
""" """
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
@@ -401,8 +401,10 @@ def load(pointer, mask=None, other=None, _builder=None):
:type mask: Block of triton.int1, optional :type mask: Block of triton.int1, optional
:param other: if mask[idx] is false, return other[idx] :param other: if mask[idx] is false, return other[idx]
:type other: Block, optional :type other: Block, optional
:param cache_modifier: cache modifier applied to the load in the generated NVIDIA PTX; one of "" (default, none), ".ca" or ".cg"
:type cache_modifier: str, optional
""" """
return frontend.load(pointer, mask, other, _builder) return frontend.load(pointer, mask, other, cache_modifier, _builder)
@builtin @builtin