From 858dec8372a2181dbac97689140723adeaa84e0e Mon Sep 17 00:00:00 2001
From: daadaada
Date: Mon, 18 Oct 2021 13:14:04 +0800
Subject: [PATCH] [CODEGEN] Add cache modifier to tl.load (#351)

* Add cache modifier to tl.load

* Add comment to cache_modifier

* Remove force_nc_cache

* Update test
---
 include/triton/codegen/pass.h                |  2 +-
 include/triton/codegen/selection/generator.h |  4 +--
 include/triton/ir/builder.h                  |  6 ++--
 include/triton/ir/dispatch.h                 |  2 +-
 include/triton/ir/instructions.h             | 32 +++++++++++++++-----
 lib/codegen/pass.cc                          |  4 +--
 lib/codegen/selection/generator.cc           | 11 +++----
 lib/codegen/transform/peephole.cc            |  5 +--
 lib/codegen/transform/pipeline.cc            | 12 +++++---
 lib/ir/builder.cc                            | 12 ++++----
 lib/ir/dispatch.cc                           | 15 +++++++--
 lib/ir/instructions.cc                       | 25 ++++++++-------
 python/src/triton.cc                         | 14 ++++-----
 python/test/unit/language/test_core.py       | 24 +++++++++++++++
 python/triton/code_gen.py                    |  8 ++---
 python/triton/language/core.py               |  6 ++--
 16 files changed, 119 insertions(+), 63 deletions(-)

diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h
index 4ded108fe..0c8f11315 100644
--- a/include/triton/codegen/pass.h
+++ b/include/triton/codegen/pass.h
@@ -33,7 +33,7 @@ namespace codegen{
 std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
                                                      codegen::target* target,
                                                      int sm, int num_warps,
-                                                     int num_stages, bool force_nc_cache, int &shared_static);
+                                                     int num_stages, int &shared_static);
 
 }
diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index f2e7263e3..a02e98d56 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -122,8 +122,7 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps,
-            bool force_nc_cache = false);
+            unsigned num_warps);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -213,7 +212,6 @@ private:
   std::set seen_;
 
   unsigned num_warps_;
-  bool force_nc_cache_;
 
   std::map offset_a_m_;
   std::map offset_a_k_;
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index cdb76cd66..40ced2bd1 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -130,9 +130,9 @@ public:
   value *create_xor(value *lhs, value *rhs);
   value *create_or(value *lhs, value *rhs);
   // Input/Output
-  value *create_load(value *arg);
+  value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
   value *create_store(value *ptr, value *val);
-  value *create_masked_load(value *arg, value *mask, value *false_value);
+  value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
   value *create_masked_store(value *ptr, value *val, value *mask);
   // Block instruction
   value *create_splat(value *arg, const type::block_shapes_t &shapes);
@@ -154,7 +154,7 @@ public:
   value *create_select(value *pred, value *if_value, value *else_value);
   // Intrinsics
   value *create_copy_to_shared(value *arg);
-  value *create_masked_load_async(value *arg, value *mask, value *false_value);
+  value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
   value *create_copy_from_shared(value *arg);
   value *create_barrier(const std::string &name = "");
   value *create_async_wait(int N);
diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
index 97b787faa..b53c89d36 100644
--- a/include/triton/ir/dispatch.h
+++ b/include/triton/ir/dispatch.h
@@ -67,7 +67,7 @@ struct dispatch{
   static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
 
   // memory operators
-  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
+  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
   static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
   static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
   static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 3a5011276..9b1ffbb79 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -394,22 +394,38 @@ public:
 
 // load
 class load_inst: public io_inst {
+public:
+  enum CACHE_MODIFIER : uint32_t {
+    NONE=0,
+    CA,
+    CG,
+  };
+
+  CACHE_MODIFIER get_cache_modifier() const { return cache_; }
 protected:
-  load_inst(value *ptr, value_id_t id, unsigned num_ops,
+  load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
             const std::string &name = "", instruction *next = nullptr);
+  std::string get_cache_modifier_repr() const {
+    if (cache_ == CA) return ".ca";
+    if (cache_ == CG) return ".cg";
+    return "";
+  }
+  CACHE_MODIFIER cache_;
 
 private:
   static type *get_pointee_type(type *ty);
+
 };
 
 // unmasked load
 class unmasked_load_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "unmasked_load"; }
-  unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
+  unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
 
 public:
   static unmasked_load_inst* create(value *ptr,
+                                    CACHE_MODIFIER cache,
                                     const std::string &name = "",
                                     instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -419,8 +435,8 @@ public:
 // masked load
 class masked_load_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load"; }
-  masked_load_inst(value *ptr, value *mask, value *false_value,
+  std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
+  masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                    const std::string &name, instruction *next);
 
 public:
@@ -429,6 +445,7 @@ public:
   value *get_false_value_operand() { return get_operand(2); }
   // factory method
   static masked_load_inst* create(value *ptr, value *mask, value *false_value,
+                                  CACHE_MODIFIER cache,
                                   const std::string &name = "",
                                   instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(masked_load_inst)
@@ -438,8 +455,8 @@ public:
 // masked load async
 class masked_load_async_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load_async_async"; }
-  masked_load_async_inst(value *ptr, value *mask, value *false_value,
+  std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
+  masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                          const std::string &name, instruction *next);
 
 public:
@@ -448,6 +465,7 @@ public:
   value *get_false_value_operand() { return get_operand(2); }
   // factory method
   static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
+                                        load_inst::CACHE_MODIFIER cache,
                                         const std::string &name = "",
                                         instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(masked_load_async_inst)
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 3af81bfe3..845e2e36d 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -25,7 +25,7 @@ namespace codegen {
 // TODO:
 // There should be a proper pass manager there!
 std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
-                                                     int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) {
+                                                     int cc, int num_warps, int num_stages, int& shared_static) {
   // generate llvm code
   std::string name = ir.get_function_list()[0]->get_name();
   std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
@@ -46,7 +46,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target);
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
   // run passes
   dce.run(ir);
   peephole.run(ir);
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 20692036f..6e31e5c7e 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps, bool force_nc_cache)
+                     unsigned num_warps)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
 
 }
 
@@ -629,10 +629,9 @@ void generator::visit_load_inst(ir::load_inst* x){
   // -----
   std::ostringstream asm_oss;
   asm_oss << "@$" << n_words; // predicate
-//  if(force_nc_cache_)
-    asm_oss << " ld.global";
-//  else
-//    asm_oss << " ld.global.cg";
+  asm_oss << " ld.global";
+  if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
+  if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
   if(n_words > 1)
     asm_oss << ".v" << n_words; // vector width
   asm_oss << ".b" << width; // word size
diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc
index 1c8920230..ac0220ebc 100644
--- a/lib/codegen/transform/peephole.cc
+++ b/lib/codegen/transform/peephole.cc
@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
   int nts = layout->nts(layout->get_order()[0]);
   int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
   if(nts*dtsize >= 4){
-    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
     copy_to_shared->replace_all_uses_with(new_load);
     return true;
   }
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
   builder.set_insert_point(select);
   ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
                                                    if_value->get_mask_operand(),
-                                                   select->get_else_value_op());
+                                                   select->get_else_value_op(),
+                                                   if_value->get_cache_modifier());
   select->replace_all_uses_with(new_load);
   return true;
 }
diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc
index c09cafe83..cc7835bbc 100644
--- a/lib/codegen/transform/pipeline.cc
+++ b/lib/codegen/transform/pipeline.cc
@@ -111,6 +111,8 @@ struct pipeline_info_t {
 };
 
 void pipeline::run(ir::module &mod) {
+  if (num_stages_ <= 1)
+    return;
   // *Very* conservative heuristics for pre-fetching.
   // A load instruction can be pipelined if:
   //   - the pointer is a phi node that references a value
@@ -176,7 +178,7 @@ void pipeline::run(ir::module &mod) {
       false_value = remat_false_value;
     }
     else
      false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
-    first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
+    first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());
     for (int stage = 1; stage < num_stages-1; ++stage) {
       // mask is the loop condition of the previous iteration
@@ -191,7 +193,7 @@ void pipeline::run(ir::module &mod) {
         first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
         false_value = remat_false_value;
       }
-      first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
+      first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
     }
 
     // create new phis for induction variables
@@ -220,7 +222,7 @@ void pipeline::run(ir::module &mod) {
       next_mask = builder.create_and(next_mask, remat_mask);
       false_value = remat_false_value;
     }
-    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
 
 
     // phi node
@@ -255,7 +257,7 @@ void pipeline::run(ir::module &mod) {
     }
     else
       false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
-    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
+    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
     // pre-fetch next iteration
     builder.set_insert_point(block->get_inst_list().back());
     ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -266,7 +268,7 @@ void pipeline::run(ir::module &mod) {
      next_mask = builder.create_and(next_mask, remat_mask);
      false_value = remat_false_value;
     }
-    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
     // phi node
     builder.set_insert_point(block->get_first_non_phi());
     ir::phi_node* new_load = builder.create_phi(ty, 2);
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index 7f06ea370..f4eadcb84 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
 // load/store instructions
 //===----------------------------------------------------------------------===//
 
-value *builder::create_load(value *ptr){
-  return insert(unmasked_load_inst::create(ptr));
+value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){
+  return insert(unmasked_load_inst::create(ptr, cache));
 }
 
 value *builder::create_store(value *ptr, value *val){
   return insert(unmasked_store_inst::create(ptr, val));
 }
 
-value *builder::create_masked_load(value *ptr, value *mask, value *false_value){
-  return insert(masked_load_inst::create(ptr, mask, false_value));
+value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){
+  return insert(masked_load_inst::create(ptr, mask, false_value, cache));
 }
 
 value *builder::create_masked_store(value *ptr, value *val, value *mask){
@@ -379,8 +379,8 @@ value *builder::create_copy_from_shared(value *arg) {
   return insert(copy_from_shared_inst::create(arg));
 }
 
-value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) {
-  return insert(masked_load_async_inst::create(ptr, mask, false_value));
+value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
+  return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
 }
 
 value *builder::create_barrier(const std::string &name) {
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 61ce7317e..cc0626fbd 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -497,7 +497,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
 // Memory Operators
 //===----------------------------------------------------------------------===//
 
-ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) {
+ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
   if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
     throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
   if(ptr->get_type()->is_block_ty()){
@@ -517,8 +517,17 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
     ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
     ptr = dispatch::cast(ptr, ptr_ty, builder);
   }
+  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
+  if (!cache_modifier.empty()) {
+    if (cache_modifier == ".ca")
+      cache = load_inst::CA;
+    else if (cache_modifier == ".cg")
+      cache = load_inst::CG;
+    else
+      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
+  }
   if (!mask && !other)
-    return builder->create_load(ptr);
+    return builder->create_load(ptr, cache);
   if (!mask)
     throw std::runtime_error("`other` cannot be provided without `mask`");
   auto shape = ptr->get_type()->get_block_shapes();
@@ -527,7 +536,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
     if(ptr->get_type()->is_block_ty())
       other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
   }
-  return builder->create_masked_load(ptr, mask, other);
+  return builder->create_masked_load(ptr, mask, other, cache);
 }
 
 ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 298492a30..b3e52e94d 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -433,8 +433,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
 { }
 
 // load_inst
-load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
-  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
+load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
+  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache)
 { }
 
 // load
@@ -447,41 +447,44 @@ type *load_inst::get_pointee_type(type *ty) {
 }
 
 // unmasked_load
-unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
-  : load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
+unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
+  : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) {
   set_operand(0, ptr);
 }
 
-unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
-  return new unmasked_load_inst(ptr, name, next);
+unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) {
+  return new unmasked_load_inst(ptr, cache, name, next);
 }
 
 // masked load
-masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
+masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                                    const std::string &name, instruction *next)
-  : load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
+  : load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) {
   set_operand(0, ptr);
   set_operand(1, mask);
   set_operand(2, false_value);
 }
 
 masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
+                                           load_inst::CACHE_MODIFIER cache,
                                            const std::string &name, instruction *next) {
-  return new masked_load_inst(ptr, mask, false_value, name, next);
+  return new masked_load_inst(ptr, mask, false_value, cache, name, next);
 }
 
 // masked load async
 masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                                               load_inst::CACHE_MODIFIER cache,
                                                const std::string &name, instruction *next)
-  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
+  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) {
   set_operand(0, ptr);
   set_operand(1, mask);
   set_operand(2, false_value);
 }
 
 masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
+                                                       load_inst::CACHE_MODIFIER cache,
                                                        const std::string &name, instruction *next) {
-  return new masked_load_async_inst(ptr, mask, false_value, name, next);
+  return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
 }
 
 // store
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 1378710e1..5f40f48f4 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -203,7 +203,7 @@ std::tuple hip_load_binary(const std::string& name, asm_map_
 // CUDA
 std::tuple cu_compile_ttir(const std::string& name, ir::module &ir,
                            uint64_t device, int num_warps, int num_stages,
-                           bool force_nc_cache, asm_map_t &asm_map){
+                           asm_map_t &asm_map){
   llvm::LLVMContext ctx;
   // device properties
   CUdevice dev = (CUdevice)device;
@@ -215,7 +215,7 @@ std::tuple cu_compile_ttir(const std::string& name,
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::nvidia_cu_target target(cc);
   int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
@@ -236,12 +236,12 @@ std::tuple cu_compile_ttir(const std::string& name,
 // HIP
 std::tuple hip_compile_ttir(const std::string& name, ir::module &ir,
                             uint64_t device, int num_warps, int num_stages,
-                            bool force_nc_cache, asm_map_t &asm_map){
+                            asm_map_t &asm_map){
   llvm::LLVMContext ctx;
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::amd_cl_target target;
   int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
@@ -255,7 +255,7 @@ std::tuple hip_compile_ttir(const std::string& name
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
+      "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
         std::string name = ir.get_function_list()[0]->get_name();
         // record asm as we generate
         asm_map_t asm_map;
@@ -264,9 +264,9 @@ void init_triton_codegen(py::module &&m) {
        asm_map["ttir"] = py::cast(ttir.str());
        llvm::LLVMContext ctx;
        if(backend == CUDA)
-          return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+          return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
        if(backend == ROCM)
-          return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+          return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
       }, py::return_value_policy::take_ownership);
   m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
     if(backend == CUDA)
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 024807392..25054f0dc 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -599,6 +599,30 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
     reference_out =torch.matmul(in1, in2)
     triton.testing.allclose(out, reference_out)
 
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
+
+    @triton.jit
+    def _kernel(dst, src, **meta):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
+        tl.store(dst+offsets, x)
+
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
+
 # ---------------
 # test store
 # ---------------
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 40c3ddee4..8622333bf 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -537,7 +537,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
         # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -560,13 +560,13 @@ class Kernel:
             backend = _triton.runtime.backend.CUDA
         else:
             backend = _triton.runtime.backend.ROCM
-        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
         max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
         if shared_mem > max_shared_memory:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)
 
-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -643,7 +643,7 @@ class Kernel:
         if binary is None:
             binary = self._compile(
                 *wargs, device=device_idx, attributes=attributes,
-                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
+                num_warps=num_warps, num_stages=num_stages,
                 constants=constants, **meta
             )
         if bin_cache_path:
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index f492b33d9..f201a9591 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -387,7 +387,7 @@ def dot(input, other, _builder=None):
 
 
 @builtin
-def load(pointer, mask=None, other=None, _builder=None):
+def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
     """
     Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
 
@@ -401,8 +401,10 @@ def load(pointer, mask=None, other=None, _builder=None):
     :type mask: Block of triton.int1, optional
     :param other: if mask[idx] is false, return other[idx]
    :type other: Block, optional
+    :param cache_modifier: changes the cache option in the generated NVIDIA PTX ('.ca' or '.cg')
+    :type cache_modifier: str, optional
     """
-    return frontend.load(pointer, mask, other, _builder)
+    return frontend.load(pointer, mask, other, cache_modifier, _builder)
 
 
 @builtin
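
Usage sketch (appended for illustration; not part of the commit): the new
cache_modifier argument as exercised by test_load_cache_modifier above. The
kernel below is adapted from that test; the kernel name and the 128-element
size are illustrative, and a CUDA device is assumed. In PTX terms, '.ca'
caches the load at all levels (L1 and L2), '.cg' caches globally (L2 only,
bypassing L1), and the default '' emits a plain ld.global.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy(dst, src, **meta):
        offsets = tl.arange(0, 128)
        # meta['CACHE'] is forwarded to tl.load and becomes the PTX cache
        # modifier on the generated ld.global instruction
        x = tl.load(src + offsets, cache_modifier=meta['CACHE'])
        tl.store(dst + offsets, x)

    src = torch.randn(128, device='cuda')
    dst = torch.empty(128, device='cuda')
    pgm = _copy[(1,)](dst, src, CACHE='.cg')  # PTX now contains ld.global.cg
    assert 'ld.global.cg' in pgm.asm['ptx']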