From 98ed7db8c1f3e7de6e67cbb3838f0692fb541be4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 24 Feb 2022 14:56:24 -0800 Subject: [PATCH] [CODEGEN] Improvements and bugfixes (#463) --- CMakeLists.txt | 11 ++++- include/triton/driver/llvm.h | 3 +- include/triton/ir/builder.h | 6 +-- include/triton/ir/dispatch.h | 3 +- include/triton/ir/instructions.h | 32 ++++++++++---- lib/codegen/selection/generator.cc | 16 ++++--- lib/codegen/transform/peephole.cc | 3 +- lib/codegen/transform/pipeline.cc | 10 ++--- lib/driver/llvm.cc | 69 ++++++++++++++++++++---------- lib/ir/builder.cc | 12 +++--- lib/ir/dispatch.cc | 32 ++++++++++---- lib/ir/instructions.cc | 30 +++++++------ python/src/triton.cc | 4 +- python/triton/language/core.py | 4 +- 14 files changed, 154 insertions(+), 81 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f44c35aa7..c3aadf9c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,15 @@ libLLVMBinaryFormat.a libLLVMAMDGPUInfo.a libLLVMSupport.a libLLVMDemangle.a +libLLVMPasses.a +libLLVMAnalysis.a +libLLVMTransformUtils.a +libLLVMScalarOpts.a +libLLVMTransformUtils.a +libLLVMipo.a +libLLVMObjCARCOpts.a +libLLVMCoroutines.a +libLLVMAnalysis.a ) endif() include_directories("${LLVM_INCLUDE_DIRS}") @@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) if(WIN32) target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32 else() - target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY}) + target_link_libraries(triton ${LLVM_LIBRARIES} z) endif() diff --git a/include/triton/driver/llvm.h b/include/triton/driver/llvm.h index 89dc98169..c0c1c0f37 100644 --- a/include/triton/driver/llvm.h +++ b/include/triton/driver/llvm.h @@ -9,8 +9,9 @@ namespace triton{ namespace driver{ void init_llvm(); +std::string path_to_ptxas(int& version); std::string llir_to_ptx(llvm::Module* module, int cc, int version); -std::string ptx_to_cubin(const std::string& ptx, int cc); +std::string 
ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc); CUmodule ptx_to_cumodule(const std::string& ptx, int cc); std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc); hipModule_t amdgpu_to_hipmodule(const std::string& path); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 67ab47c90..2b6bc6ab3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -136,9 +136,9 @@ public: value *create_xor(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs); // Input/Output - value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile); + value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_store(value *ptr, value *val); - value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile); + value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); @@ -163,7 +163,7 @@ public: // These have no place in the IR, and hopefully they can be removed at some point value *create_umulhi(value* lhs, value* rhs); value *create_copy_to_shared(value *arg); - value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); + value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY); value *create_copy_from_shared(value *arg); value *create_barrier(const std::string &name = ""); value *create_async_wait(int N); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index 7dc8ab0ed..ef14043dd 100644 --- a/include/triton/ir/dispatch.h +++ 
b/include/triton/ir/dispatch.h @@ -69,7 +69,8 @@ struct dispatch{ static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder); + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, + const std::string& eviction_policy, int is_volatile, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 5af077e8f..0fb85db02 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -406,13 +406,20 @@ public: NONE=0, CA, CG, - }; + }; + + enum EVICTION_POLICY : uint32_t { + NORMAL=0, + EVICT_FIRST, + EVICT_LAST, + }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } + EVICTION_POLICY get_eviction_policy() const { return eviction_; } bool get_is_volatile() const { return is_volatile_; } protected: - load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, + load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name = "", instruction *next = nullptr); std::string get_cache_modifier_repr() const { @@ -420,6 +427,11 @@ protected: if (cache_ == CG) return ".cg"; return ""; } + std::string get_eviction_policy_repr() const { // textual suffix for eviction_, same spelling generator.cc emits; "" for NORMAL + if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; + return eviction_ == EVICT_LAST ? ".L1::evict_last" : ""; + } + EVICTION_POLICY eviction_; CACHE_MODIFIER cache_; std::string get_volatile_repr() { @@ -435,11 +447,12 @@ private: class 
unmasked_load_inst: public load_inst { private: std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } - unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); + unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); public: static unmasked_load_inst* create(value *ptr, - CACHE_MODIFIER cache, bool is_volatile, + CACHE_MODIFIER cache, EVICTION_POLICY eviction, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_load_inst) @@ -450,7 +463,7 @@ public: class masked_load_inst: public load_inst { private: std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } - masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, + masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); public: @@ -459,7 +472,8 @@ public: value *get_false_value_operand() { return get_operand(2); } // factory method static masked_load_inst* create(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, bool is_volatile, + CACHE_MODIFIER cache, EVICTION_POLICY eviction, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_inst) @@ -470,8 +484,9 @@ public: class masked_load_async_inst: public load_inst { private: std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); } - masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, - const std::string &name, instruction *next); + masked_load_async_inst(value *ptr, value *mask, value *false_value, + CACHE_MODIFIER cache, 
EVICTION_POLICY eviction, + const std::string &name, instruction *next); public: // accessors @@ -480,6 +495,7 @@ public: // factory method static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, + EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_async_inst) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d2ebce1c6..b4f1dd41e 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) -#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr) #define lshr(...) builder_->CreateLShr(__VA_ARGS__) #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) #define min_num(...) 
builder_->CreateMinNum(__VA_ARGS__) @@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) { // <> BF16 if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ // FP32 -> BF16 - if(op_sca_ty->is_fp32_ty()) - // for(size_t i = 0; i < x_idxs.size(); i++) - // vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + if(op_sca_ty->is_fp32_ty()){ for (indices_t idx: idxs_.at(x)) { Value *arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); } + return; + } // BF16 -> FP32 - if(ret_sca_ty->is_fp32_ty()) + if(ret_sca_ty->is_fp32_ty()){ for(size_t i = 0; i < x_idxs.size(); i++) vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); - return; + return; + } } @@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){ std::ostringstream asm_oss; asm_oss << "@$" << n_words; // predicate asm_oss << " ld"; -// std::cout << x->get_is_volatile() << std::endl; if(x->get_is_volatile()) asm_oss << ".volatile"; asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index e30ab9b35..0961efc9c 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build int nts = layout->nts(layout->get_order()[0]); int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; if(nts*dtsize >= 4){ - ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, 
ld->get_cache_modifier()); + ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy()); copy_to_shared->replace_all_uses_with(new_load); return true; } @@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b if_value->get_mask_operand(), select->get_else_value_op(), if_value->get_cache_modifier(), + if_value->get_eviction_policy(), if_value->get_is_volatile()); select->replace_all_uses_with(new_load); return true; diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index eb3fe6164..c85ba43a1 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) { false_value = remat_false_value; } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile()); + first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); for (int stage = 1; stage < num_stages-1; ++stage) { // mask is the loop condition of the previous iteration @@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) { first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); false_value = remat_false_value; } - first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile()); + first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); } // create new phis for induction variables @@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) { next_mask = 
builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // phi node @@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) { } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); ir::value* next_ptr = ptr->get_value_for_block(block); @@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // phi node builder.set_insert_point(block->get_first_non_phi()); ir::phi_node* new_load = builder.create_phi(ty, 2); diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index be8200b86..e7bb47bef 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -59,6 +59,13 @@ #include "llvm/Analysis/TargetLibraryInfo.h" // end AMD stuff +extern "C"{ + int set_curterm(char* nterm){ return 0; } + int del_curterm(char* nterm){ return 0; } + int tigetnum(char *capname) { return 0; } + int setupterm(char *term, int fildes, 
int *errret) { return 0; } +} + namespace triton{ namespace driver{ @@ -77,6 +84,7 @@ void init_llvm() { } } + /* ------------------------ */ // CUDA // /* ------------------------ */ @@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s return true; } +std::string path_to_ptxas(int& version) { // returns the ptxas command to run; sets version = major*1000 + minor*10; throws std::runtime_error if no ptxas runs or its version cannot be parsed + std::string ret; + // search paths for ptxas + std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; + std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); + if(!triton_ptxas.empty()) + ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); + // probe in priority order and stop at the first ptxas that runs, so `ret` holds the --version output of the binary that is returned + std::vector<std::string> working_ptxas; + for(size_t i = 0; i < ptxas_prefixes.size() && working_ptxas.empty(); i++){ + std::string ptxas = ptxas_prefixes[i] + "ptxas"; + bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; + if(works) + working_ptxas.push_back(ptxas); + } + // error if no working ptxas was found + if(working_ptxas.empty()) + throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" + " but a working version could not be found."); + std::string ptxas = working_ptxas.front(); + // parse version + std::regex version_regex("release (\\d+)\\.(\\d+)"); + std::smatch match; + if(std::regex_search(ret, match, version_regex)){ + int major = std::stoi(match[1]); + int minor = std::stoi(match[2]); + version = major*1000 + minor*10; + } + else + throw std::runtime_error("couldn't parse ptxas version: " + ret); + return ptxas; +} + + int vptx(int version){ + if(version >= 11040) return 74; if(version >= 11030) return 73; if(version >= 11020) return 72; if(version >= 11010) return 71; @@ -103,7 +146,7 @@ int vptx(int version){ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; - int max_nvvm_ptx = 64; + int max_nvvm_ptx = 74; // options auto options = llvm::cl::getRegisteredOptions(); auto* short_ptr = 
static_cast*>(options["nvptx-short-ptr"]); @@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ std::string triple = "nvptx64-nvidia-cuda"; std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc)); std::string layout = ""; - std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); + std::string features = ""; + // std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; @@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ return result; } -std::string ptx_to_cubin(const std::string& ptx, int cc) { - std::string version; - // search pathes for ptxas - std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; - std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); - if(!triton_ptxas.empty()) - ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); - // see what path for ptxas are valid - std::vector working_ptxas; - for(std::string prefix: ptxas_prefixes){ - std::string ptxas = prefix + "ptxas"; - bool works = tools::exec(ptxas + " --version 2>&1", version) == 0; - if(works) - working_ptxas.push_back(ptxas); - } - // error if no working ptxas was found - if(working_ptxas.empty()) - throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" - " but a working version could not be found."); - std::string ptxas = working_ptxas.front(); +std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; char _flog[L_tmpnam]; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index bf59b5b39..fff73e665 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value 
*ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){ - return insert(unmasked_load_inst::create(ptr, cache, is_volatile)); +value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ + return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); } value *builder::create_store(value *ptr, value *val){ return insert(unmasked_store_inst::create(ptr, val)); } -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){ - return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile)); +value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ + return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); } value *builder::create_masked_store(value *ptr, value *val, value *mask){ @@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) { return insert(copy_from_shared_inst::create(arg)); } -value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) { - return insert(masked_load_async_inst::create(ptr, mask, false_value, cache)); +value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) { + return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction)); } value *builder::create_barrier(const std::string &name) { diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 5a2d33a98..664fbb983 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build return input; ir::type *src_sca_ty = src_ty->get_scalar_ty(); ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); 
+ // + if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || + (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ + return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); + } // FP Truncation bool truncate_fp = src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_floating_point_ty() && @@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build // Memory Operators //===----------------------------------------------------------------------===// -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) { +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); if(ptr->get_type()->is_block_ty()){ - if(mask){ + if(mask) mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - } - if(other){ + if(other) other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); - other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - } } + if(other) + other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); // treat bool* as int8* @@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con else throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); } + // eviction policy + load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default + if(!eviction_policy.empty()){ + if (eviction_policy == "evict_last") + eviction 
= load_inst::EVICT_LAST; + else if(eviction_policy == "evict_first") + eviction = load_inst::EVICT_FIRST; + else + throw std::runtime_error(std::string("Eviction policy ") + eviction_policy + " not supported"); + } + + if (!mask && !other) - return builder->create_load(ptr, cache, is_volatile); + return builder->create_load(ptr, cache, eviction, is_volatile); if (!mask) throw std::runtime_error("`other` cannot be provided without `mask`"); auto shape = ptr->get_type()->get_block_shapes(); @@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con if(ptr->get_type()->is_block_ty()) other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); } - return builder->create_masked_load(ptr, mask, other, cache, is_volatile); + return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); } ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index ac5a17289..c225b315f 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n { } // load_inst -load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile) +load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile) { } // load @@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) { } // unmasked_load -unmasked_load_inst::unmasked_load_inst(value *ptr, 
load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) { +unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) + : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) { set_operand(0, ptr); } -unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { - return new unmasked_load_inst(ptr, cache, is_volatile, name, next); +unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) { + return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next); } // masked load -masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, +masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, + bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) { + : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, bool is_volatile, + load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, + bool is_volatile, const std::string &name, instruction *next) { - return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next); + return new masked_load_inst(ptr, mask, false_value, cache, eviction, 
is_volatile, name, next); } // masked load async masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) { + : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, const std::string &name, instruction *next) { - return new masked_load_async_inst(ptr, mask, false_value, cache, name, next); + return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next); } // store diff --git a/python/src/triton.cc b/python/src/triton.cc index 3410df6b8..c5c5b196f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -472,7 +472,7 @@ std::tuple cu_compile_ttir(const std::string& name, size_t minor = cuGetInfo(dev); size_t cc = major*10 + minor; int version; - drv::dispatch::cuDriverGetVersion(&version); + std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); @@ -485,7 +485,7 @@ std::tuple cu_compile_ttir(const std::string& name, std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version); asm_map["ptx"] = py::cast(ptx); // PTX -> Binary - std::string cubin = drv::ptx_to_cubin(ptx, cc); + std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc); if(!cubin.empty()){ py::bytes bytes(cubin); asm_map["cubin"] = bytes; diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f4188f1c7..df25e59fb 
100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin -def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None): +def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. @@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder) + return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin