[CODEGEN] Improvements and bugfixes (#463)
This commit is contained in:
		@@ -110,6 +110,15 @@ libLLVMBinaryFormat.a
 | 
				
			|||||||
libLLVMAMDGPUInfo.a
 | 
					libLLVMAMDGPUInfo.a
 | 
				
			||||||
libLLVMSupport.a
 | 
					libLLVMSupport.a
 | 
				
			||||||
libLLVMDemangle.a
 | 
					libLLVMDemangle.a
 | 
				
			||||||
 | 
					libLLVMPasses.a
 | 
				
			||||||
 | 
					libLLVMAnalysis.a
 | 
				
			||||||
 | 
					libLLVMTransformUtils.a
 | 
				
			||||||
 | 
					libLLVMScalarOpts.a
 | 
				
			||||||
 | 
					libLLVMTransformUtils.a
 | 
				
			||||||
 | 
					libLLVMipo.a
 | 
				
			||||||
 | 
					libLLVMObjCARCOpts.a
 | 
				
			||||||
 | 
					libLLVMCoroutines.a
 | 
				
			||||||
 | 
					libLLVMAnalysis.a
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
include_directories("${LLVM_INCLUDE_DIRS}")
 | 
					include_directories("${LLVM_INCLUDE_DIRS}")
 | 
				
			||||||
@@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
 | 
				
			|||||||
if(WIN32)
 | 
					if(WIN32)
 | 
				
			||||||
    target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
 | 
					    target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
 | 
				
			||||||
else()
 | 
					else()
 | 
				
			||||||
    target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
 | 
					    target_link_libraries(triton ${LLVM_LIBRARIES} z)
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,8 +9,9 @@ namespace triton{
 | 
				
			|||||||
namespace driver{
 | 
					namespace driver{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void init_llvm();
 | 
					void init_llvm();
 | 
				
			||||||
 | 
					std::string path_to_ptxas(int& version);
 | 
				
			||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
 | 
					std::string llir_to_ptx(llvm::Module* module, int cc, int version);
 | 
				
			||||||
std::string ptx_to_cubin(const std::string& ptx, int cc);
 | 
					std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
 | 
				
			||||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 | 
					CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
 | 
				
			||||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
 | 
					std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
 | 
				
			||||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
 | 
					hipModule_t amdgpu_to_hipmodule(const std::string& path);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -136,9 +136,9 @@ public:
 | 
				
			|||||||
  value *create_xor(value *lhs, value *rhs);
 | 
					  value *create_xor(value *lhs, value *rhs);
 | 
				
			||||||
  value *create_or(value *lhs, value *rhs);
 | 
					  value *create_or(value *lhs, value *rhs);
 | 
				
			||||||
  // Input/Output
 | 
					  // Input/Output
 | 
				
			||||||
  value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile);
 | 
					  value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
 | 
				
			||||||
  value *create_store(value *ptr, value *val);
 | 
					  value *create_store(value *ptr, value *val);
 | 
				
			||||||
  value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile);
 | 
					  value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
 | 
				
			||||||
  value *create_masked_store(value *ptr, value *val, value *mask);
 | 
					  value *create_masked_store(value *ptr, value *val, value *mask);
 | 
				
			||||||
  // Block instruction
 | 
					  // Block instruction
 | 
				
			||||||
  value *create_splat(value *arg, const type::block_shapes_t &shapes);
 | 
					  value *create_splat(value *arg, const type::block_shapes_t &shapes);
 | 
				
			||||||
@@ -163,7 +163,7 @@ public:
 | 
				
			|||||||
  // These have no place in the IR, and hopefully they can be removed at some point
 | 
					  // These have no place in the IR, and hopefully they can be removed at some point
 | 
				
			||||||
  value *create_umulhi(value* lhs, value* rhs);
 | 
					  value *create_umulhi(value* lhs, value* rhs);
 | 
				
			||||||
  value *create_copy_to_shared(value *arg);
 | 
					  value *create_copy_to_shared(value *arg);
 | 
				
			||||||
  value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
 | 
					  value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
 | 
				
			||||||
  value *create_copy_from_shared(value *arg);
 | 
					  value *create_copy_from_shared(value *arg);
 | 
				
			||||||
  value *create_barrier(const std::string &name = "");
 | 
					  value *create_barrier(const std::string &name = "");
 | 
				
			||||||
  value *create_async_wait(int N);
 | 
					  value *create_async_wait(int N);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -69,7 +69,8 @@ struct dispatch{
 | 
				
			|||||||
  static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
 | 
					  static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // memory operators
 | 
					  // memory operators
 | 
				
			||||||
  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder);
 | 
					  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
 | 
				
			||||||
 | 
					                         const std::string& eviction_policy, int is_volatile, ir::builder *builder);
 | 
				
			||||||
  static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
 | 
					  static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
 | 
				
			||||||
  static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
 | 
					  static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
 | 
				
			||||||
  static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
 | 
					  static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -408,11 +408,18 @@ public:
 | 
				
			|||||||
    CG,
 | 
					    CG,
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Eviction hint attached to a load instruction.
					  // The numeric values are spelled out so they stay fixed if entries are
					  // ever reordered; they match the implicit numbering used previously.
					  enum EVICTION_POLICY : uint32_t {
					    NORMAL      = 0,  // default: no eviction suffix is produced for it
					    EVICT_FIRST = 1,
					    EVICT_LAST  = 2
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  CACHE_MODIFIER get_cache_modifier() const { return cache_; }
 | 
					  CACHE_MODIFIER get_cache_modifier() const { return cache_; }
 | 
				
			||||||
 | 
					  EVICTION_POLICY get_eviction_policy() const { return eviction_; }
 | 
				
			||||||
  bool get_is_volatile() const { return is_volatile_; }
 | 
					  bool get_is_volatile() const { return is_volatile_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
protected:
 | 
					protected:
 | 
				
			||||||
  load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
 | 
					  load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
 | 
				
			||||||
          bool is_volatile,
 | 
					          bool is_volatile,
 | 
				
			||||||
          const std::string &name = "", instruction *next = nullptr);
 | 
					          const std::string &name = "", instruction *next = nullptr);
 | 
				
			||||||
  std::string get_cache_modifier_repr() const {
 | 
					  std::string get_cache_modifier_repr() const {
 | 
				
			||||||
@@ -420,6 +427,11 @@ protected:
 | 
				
			|||||||
    if (cache_ == CG) return ".cg";
 | 
					    if (cache_ == CG) return ".cg";
 | 
				
			||||||
    return ""; 
 | 
					    return ""; 
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  std::string get_eviction_policy_repr() const {
 | 
				
			||||||
 | 
					    if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
 | 
				
			||||||
 | 
					    if (eviction_ == EVICT_LAST) return ".L2::evict_last";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  EVICTION_POLICY eviction_;
 | 
				
			||||||
  CACHE_MODIFIER cache_;
 | 
					  CACHE_MODIFIER cache_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::string get_volatile_repr() {
 | 
					  std::string get_volatile_repr() {
 | 
				
			||||||
@@ -435,11 +447,12 @@ private:
 | 
				
			|||||||
class unmasked_load_inst: public load_inst {
 | 
					class unmasked_load_inst: public load_inst {
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
  std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
 | 
					  std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
 | 
				
			||||||
  unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next);
 | 
					  unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
  static unmasked_load_inst* create(value *ptr,
 | 
					  static unmasked_load_inst* create(value *ptr,
 | 
				
			||||||
                                    CACHE_MODIFIER cache, bool is_volatile,
 | 
					                                    CACHE_MODIFIER cache, EVICTION_POLICY eviction,
 | 
				
			||||||
 | 
					                                    bool is_volatile,
 | 
				
			||||||
                                    const std::string &name = "",
 | 
					                                    const std::string &name = "",
 | 
				
			||||||
                                    instruction *next = nullptr);
 | 
					                                    instruction *next = nullptr);
 | 
				
			||||||
  _TRITON_DEFINE_CLONE(unmasked_load_inst)
 | 
					  _TRITON_DEFINE_CLONE(unmasked_load_inst)
 | 
				
			||||||
@@ -450,7 +463,7 @@ public:
 | 
				
			|||||||
class masked_load_inst: public load_inst {
 | 
					class masked_load_inst: public load_inst {
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
  std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
 | 
					  std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
 | 
				
			||||||
  masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
 | 
					  masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
 | 
				
			||||||
                   const std::string &name, instruction *next);
 | 
					                   const std::string &name, instruction *next);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
@@ -459,7 +472,8 @@ public:
 | 
				
			|||||||
  value *get_false_value_operand() { return get_operand(2); }
 | 
					  value *get_false_value_operand() { return get_operand(2); }
 | 
				
			||||||
  // factory method
 | 
					  // factory method
 | 
				
			||||||
  static masked_load_inst* create(value *ptr, value *mask, value *false_value,
 | 
					  static masked_load_inst* create(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                                  CACHE_MODIFIER cache, bool is_volatile,
 | 
					                                  CACHE_MODIFIER cache, EVICTION_POLICY eviction,
 | 
				
			||||||
 | 
					                                  bool is_volatile,
 | 
				
			||||||
                                  const std::string &name = "",
 | 
					                                  const std::string &name = "",
 | 
				
			||||||
                                  instruction *next = nullptr);
 | 
					                                  instruction *next = nullptr);
 | 
				
			||||||
  _TRITON_DEFINE_CLONE(masked_load_inst)
 | 
					  _TRITON_DEFINE_CLONE(masked_load_inst)
 | 
				
			||||||
@@ -470,8 +484,9 @@ public:
 | 
				
			|||||||
class masked_load_async_inst: public load_inst {
 | 
					class masked_load_async_inst: public load_inst {
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
  std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
 | 
					  std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
 | 
				
			||||||
  masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
 | 
					  masked_load_async_inst(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                   const std::string &name, instruction *next);
 | 
					                         CACHE_MODIFIER cache, EVICTION_POLICY eviction,
 | 
				
			||||||
 | 
					                         const std::string &name, instruction *next);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
  // accessors
 | 
					  // accessors
 | 
				
			||||||
@@ -480,6 +495,7 @@ public:
 | 
				
			|||||||
  // factory method
 | 
					  // factory method
 | 
				
			||||||
  static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
 | 
					  static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                                  load_inst::CACHE_MODIFIER cache,
 | 
					                                  load_inst::CACHE_MODIFIER cache,
 | 
				
			||||||
 | 
					                                  EVICTION_POLICY eviction,
 | 
				
			||||||
                                  const std::string &name = "",
 | 
					                                  const std::string &name = "",
 | 
				
			||||||
                                  instruction *next = nullptr);
 | 
					                                  instruction *next = nullptr);
 | 
				
			||||||
  _TRITON_DEFINE_CLONE(masked_load_async_inst)
 | 
					  _TRITON_DEFINE_CLONE(masked_load_async_inst)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 | 
				
			|||||||
#define icmp_ult(...)        builder_->CreateICmpULT(__VA_ARGS__)
 | 
					#define icmp_ult(...)        builder_->CreateICmpULT(__VA_ARGS__)
 | 
				
			||||||
#define insert_elt(...)      builder_->CreateInsertElement(__VA_ARGS__)
 | 
					#define insert_elt(...)      builder_->CreateInsertElement(__VA_ARGS__)
 | 
				
			||||||
#define intrinsic(...)       builder_->CreateIntrinsic(__VA_ARGS__)
 | 
					#define intrinsic(...)       builder_->CreateIntrinsic(__VA_ARGS__)
 | 
				
			||||||
#define load(...)            builder_->CreateLoad(__VA_ARGS__)
 | 
					// Typed load: the loaded type is derived from the pointer's pointee type.
					// The argument is parenthesized so that compound expressions (e.g. a
					// conditional) bind correctly. `ptr` is still expanded twice, so pass an
					// expression that is cheap and free of side effects.
					#define load(ptr)            builder_->CreateLoad((ptr)->getType()->getPointerElementType(), (ptr))
 | 
				
			||||||
#define lshr(...)            builder_->CreateLShr(__VA_ARGS__)
 | 
					#define lshr(...)            builder_->CreateLShr(__VA_ARGS__)
 | 
				
			||||||
#define max_num(...)         builder_->CreateMaxNum(__VA_ARGS__)
 | 
					#define max_num(...)         builder_->CreateMaxNum(__VA_ARGS__)
 | 
				
			||||||
#define min_num(...)         builder_->CreateMinNum(__VA_ARGS__)
 | 
					#define min_num(...)         builder_->CreateMinNum(__VA_ARGS__)
 | 
				
			||||||
@@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
 | 
				
			|||||||
  // <> BF16
 | 
					  // <> BF16
 | 
				
			||||||
  if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
 | 
					  if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
 | 
				
			||||||
    // FP32 -> BF16
 | 
					    // FP32 -> BF16
 | 
				
			||||||
    if(op_sca_ty->is_fp32_ty())
 | 
					    if(op_sca_ty->is_fp32_ty()){
 | 
				
			||||||
      // for(size_t i = 0; i < x_idxs.size(); i++)
 | 
					 | 
				
			||||||
      //   vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
 | 
					 | 
				
			||||||
      for (indices_t idx: idxs_.at(x)) {
 | 
					      for (indices_t idx: idxs_.at(x)) {
 | 
				
			||||||
        Value *arg = vals_[x->get_operand(0)][idx];
 | 
					        Value *arg = vals_[x->get_operand(0)][idx];
 | 
				
			||||||
        vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
 | 
					        vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    // BF16 -> FP32
 | 
					    // BF16 -> FP32
 | 
				
			||||||
    if(ret_sca_ty->is_fp32_ty())
 | 
					    if(ret_sca_ty->is_fp32_ty()){
 | 
				
			||||||
      for(size_t i = 0; i < x_idxs.size(); i++)
 | 
					      for(size_t i = 0; i < x_idxs.size(); i++)
 | 
				
			||||||
        vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
 | 
					        vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
 | 
				
			||||||
    return;
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){
 | 
				
			|||||||
    std::ostringstream asm_oss;
 | 
					    std::ostringstream asm_oss;
 | 
				
			||||||
    asm_oss << "@$" << n_words; // predicate
 | 
					    asm_oss << "@$" << n_words; // predicate
 | 
				
			||||||
    asm_oss << " ld";
 | 
					    asm_oss << " ld";
 | 
				
			||||||
//    std::cout << x->get_is_volatile() << std::endl;
 | 
					 | 
				
			||||||
    if(x->get_is_volatile())
 | 
					    if(x->get_is_volatile())
 | 
				
			||||||
      asm_oss << ".volatile";
 | 
					      asm_oss << ".volatile";
 | 
				
			||||||
    asm_oss << ".global";
 | 
					    asm_oss << ".global";
 | 
				
			||||||
    if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
 | 
					    if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
 | 
				
			||||||
    if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
 | 
					    if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
 | 
				
			||||||
 | 
					    if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
 | 
				
			||||||
 | 
					    if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
 | 
				
			||||||
    if(n_words > 1)
 | 
					    if(n_words > 1)
 | 
				
			||||||
      asm_oss << ".v" << n_words; // vector width
 | 
					      asm_oss << ".v" << n_words; // vector width
 | 
				
			||||||
    asm_oss << ".b" << width; // word size
 | 
					    asm_oss << ".b" << width; // word size
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
 | 
				
			|||||||
  int nts = layout->nts(layout->get_order()[0]);
 | 
					  int nts = layout->nts(layout->get_order()[0]);
 | 
				
			||||||
  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
 | 
					  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
 | 
				
			||||||
  if(nts*dtsize >= 4){
 | 
					  if(nts*dtsize >= 4){
 | 
				
			||||||
    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
 | 
					    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
 | 
				
			||||||
    copy_to_shared->replace_all_uses_with(new_load);
 | 
					    copy_to_shared->replace_all_uses_with(new_load);
 | 
				
			||||||
    return true;
 | 
					    return true;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
 | 
				
			|||||||
                                                   if_value->get_mask_operand(),
 | 
					                                                   if_value->get_mask_operand(),
 | 
				
			||||||
                                                   select->get_else_value_op(),
 | 
					                                                   select->get_else_value_op(),
 | 
				
			||||||
                                                   if_value->get_cache_modifier(),
 | 
					                                                   if_value->get_cache_modifier(),
 | 
				
			||||||
 | 
					                                                   if_value->get_eviction_policy(),
 | 
				
			||||||
                                                   if_value->get_is_volatile());
 | 
					                                                   if_value->get_is_volatile());
 | 
				
			||||||
  select->replace_all_uses_with(new_load);
 | 
					  select->replace_all_uses_with(new_load);
 | 
				
			||||||
  return true;
 | 
					  return true;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) {
 | 
				
			|||||||
        false_value = remat_false_value;
 | 
					        false_value = remat_false_value;
 | 
				
			||||||
      } else
 | 
					      } else
 | 
				
			||||||
        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
 | 
					        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
 | 
				
			||||||
      first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile());
 | 
					      first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      for (int stage = 1; stage < num_stages-1; ++stage) {
 | 
					      for (int stage = 1; stage < num_stages-1; ++stage) {
 | 
				
			||||||
        // mask is the loop condition of the previous iteration
 | 
					        // mask is the loop condition of the previous iteration
 | 
				
			||||||
@@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) {
 | 
				
			|||||||
          first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
 | 
					          first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
 | 
				
			||||||
          false_value = remat_false_value;
 | 
					          false_value = remat_false_value;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile());
 | 
					        first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // create new phis for induction variables
 | 
					      // create new phis for induction variables
 | 
				
			||||||
@@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) {
 | 
				
			|||||||
        next_mask = builder.create_and(next_mask, remat_mask);
 | 
					        next_mask = builder.create_and(next_mask, remat_mask);
 | 
				
			||||||
        false_value = remat_false_value;
 | 
					        false_value = remat_false_value;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
 | 
					      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // phi node
 | 
					      // phi node
 | 
				
			||||||
@@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
      else
 | 
					      else
 | 
				
			||||||
        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
 | 
					        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
 | 
				
			||||||
      ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
 | 
					      ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
 | 
				
			||||||
      // pre-fetch next iteration
 | 
					      // pre-fetch next iteration
 | 
				
			||||||
      builder.set_insert_point(block->get_inst_list().back());
 | 
					      builder.set_insert_point(block->get_inst_list().back());
 | 
				
			||||||
      ir::value* next_ptr = ptr->get_value_for_block(block);
 | 
					      ir::value* next_ptr = ptr->get_value_for_block(block);
 | 
				
			||||||
@@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) {
 | 
				
			|||||||
        next_mask = builder.create_and(next_mask, remat_mask);
 | 
					        next_mask = builder.create_and(next_mask, remat_mask);
 | 
				
			||||||
        false_value = remat_false_value;
 | 
					        false_value = remat_false_value;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
 | 
					      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
 | 
				
			||||||
      // phi node
 | 
					      // phi node
 | 
				
			||||||
      builder.set_insert_point(block->get_first_non_phi());
 | 
					      builder.set_insert_point(block->get_first_non_phi());
 | 
				
			||||||
      ir::phi_node* new_load = builder.create_phi(ty, 2);
 | 
					      ir::phi_node* new_load = builder.create_phi(ty, 2);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -59,6 +59,13 @@
 | 
				
			|||||||
#include "llvm/Analysis/TargetLibraryInfo.h"
 | 
					#include "llvm/Analysis/TargetLibraryInfo.h"
 | 
				
			||||||
// end AMD stuff
 | 
					// end AMD stuff
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// No-op definitions of four terminfo entry points with C linkage.
					// Each one ignores its arguments and reports 0.
					// NOTE(review): presumably these exist so the link no longer needs
					// ${TERMINFO_LIBRARY}, which the same change drops from CMake -- confirm.
					extern "C" {

					  int set_curterm(char* /*nterm*/) {
					    return 0;
					  }

					  int del_curterm(char* /*nterm*/) {
					    return 0;
					  }

					  int tigetnum(char* /*capname*/) {
					    return 0;
					  }

					  int setupterm(char* /*term*/, int /*fildes*/, int* /*errret*/) {
					    return 0;
					  }

					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace triton{
 | 
					namespace triton{
 | 
				
			||||||
namespace driver{
 | 
					namespace driver{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -77,6 +84,7 @@ void init_llvm() {
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ------------------------ */
 | 
					/* ------------------------ */
 | 
				
			||||||
//         CUDA             //
 | 
					//         CUDA             //
 | 
				
			||||||
/* ------------------------ */
 | 
					/* ------------------------ */
 | 
				
			||||||
@@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
 | 
				
			|||||||
  return true;
 | 
					  return true;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::string path_to_ptxas(int& version) {
 | 
				
			||||||
 | 
					  std::string ret;
 | 
				
			||||||
 | 
					  // search pathes for ptxas
 | 
				
			||||||
 | 
					  std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
 | 
				
			||||||
 | 
					  std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
 | 
				
			||||||
 | 
					  if(!triton_ptxas.empty())
 | 
				
			||||||
 | 
					    ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
 | 
				
			||||||
 | 
					  // see what path for ptxas are valid
 | 
				
			||||||
 | 
					  std::vector<std::string> working_ptxas;
 | 
				
			||||||
 | 
					  for(std::string prefix: ptxas_prefixes){
 | 
				
			||||||
 | 
					    std::string ptxas = prefix + "ptxas";
 | 
				
			||||||
 | 
					    bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
 | 
				
			||||||
 | 
					    if(works)
 | 
				
			||||||
 | 
					      working_ptxas.push_back(ptxas);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // error if no working ptxas was found
 | 
				
			||||||
 | 
					  if(working_ptxas.empty())
 | 
				
			||||||
 | 
					    throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
 | 
				
			||||||
 | 
					                             " but a working version could not be found.");
 | 
				
			||||||
 | 
					  std::string ptxas = working_ptxas.front();
 | 
				
			||||||
 | 
					  // parse version
 | 
				
			||||||
 | 
					  std::regex version_regex("release (\\d+)\\.(\\d+)");
 | 
				
			||||||
 | 
					  std::smatch match;
 | 
				
			||||||
 | 
					  if(std::regex_search(ret, match, version_regex)){
 | 
				
			||||||
 | 
					    int major = std::stoi(match[1]);
 | 
				
			||||||
 | 
					    int minor = std::stoi(match[2]);
 | 
				
			||||||
 | 
					    version = major*1000 + minor*10;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  else
 | 
				
			||||||
 | 
					    throw std::runtime_error("couldn't parse ptxas version: " + ret);
 | 
				
			||||||
 | 
					  return ptxas;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int vptx(int version){
 | 
					int vptx(int version){
 | 
				
			||||||
 | 
					  if(version >= 11040) return 74;
 | 
				
			||||||
  if(version >= 11030) return 73;
 | 
					  if(version >= 11030) return 73;
 | 
				
			||||||
  if(version >= 11020) return 72;
 | 
					  if(version >= 11020) return 72;
 | 
				
			||||||
  if(version >= 11010) return 71;
 | 
					  if(version >= 11010) return 71;
 | 
				
			||||||
@@ -103,7 +146,7 @@ int vptx(int version){
 | 
				
			|||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
 | 
					std::string llir_to_ptx(llvm::Module* module, int cc, int version){
 | 
				
			||||||
  // LLVM version in use may not officially support target hardware
 | 
					  // LLVM version in use may not officially support target hardware
 | 
				
			||||||
  int max_nvvm_cc = 75;
 | 
					  int max_nvvm_cc = 75;
 | 
				
			||||||
  int max_nvvm_ptx = 64;
 | 
					  int max_nvvm_ptx = 74;
 | 
				
			||||||
  // options
 | 
					  // options
 | 
				
			||||||
  auto options = llvm::cl::getRegisteredOptions();
 | 
					  auto options = llvm::cl::getRegisteredOptions();
 | 
				
			||||||
  auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
 | 
					  auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
 | 
				
			||||||
@@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
 | 
				
			|||||||
  std::string triple = "nvptx64-nvidia-cuda";
 | 
					  std::string triple = "nvptx64-nvidia-cuda";
 | 
				
			||||||
  std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
 | 
					  std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
 | 
				
			||||||
  std::string layout = "";
 | 
					  std::string layout = "";
 | 
				
			||||||
  std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
 | 
					  std::string features = "";
 | 
				
			||||||
 | 
					  // std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
 | 
				
			||||||
  init_llvm();
 | 
					  init_llvm();
 | 
				
			||||||
  // verify and store llvm
 | 
					  // verify and store llvm
 | 
				
			||||||
  llvm::legacy::PassManager pm;
 | 
					  llvm::legacy::PassManager pm;
 | 
				
			||||||
@@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
 | 
				
			|||||||
  return result;
 | 
					  return result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::string ptx_to_cubin(const std::string& ptx, int cc) {
 | 
					std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
 | 
				
			||||||
  std::string version;
 | 
					 | 
				
			||||||
  // search pathes for ptxas
 | 
					 | 
				
			||||||
  std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
 | 
					 | 
				
			||||||
  std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
 | 
					 | 
				
			||||||
  if(!triton_ptxas.empty())
 | 
					 | 
				
			||||||
    ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
 | 
					 | 
				
			||||||
  // see what path for ptxas are valid
 | 
					 | 
				
			||||||
  std::vector<std::string> working_ptxas;
 | 
					 | 
				
			||||||
  for(std::string prefix: ptxas_prefixes){
 | 
					 | 
				
			||||||
    std::string ptxas = prefix + "ptxas";
 | 
					 | 
				
			||||||
    bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
 | 
					 | 
				
			||||||
    if(works)
 | 
					 | 
				
			||||||
      working_ptxas.push_back(ptxas);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  // error if no working ptxas was found
 | 
					 | 
				
			||||||
  if(working_ptxas.empty())
 | 
					 | 
				
			||||||
    throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
 | 
					 | 
				
			||||||
                             " but a working version could not be found.");
 | 
					 | 
				
			||||||
  std::string ptxas = working_ptxas.front();
 | 
					 | 
				
			||||||
  // compile ptx with ptxas
 | 
					  // compile ptx with ptxas
 | 
				
			||||||
  char _fsrc[L_tmpnam];
 | 
					  char _fsrc[L_tmpnam];
 | 
				
			||||||
  char _flog[L_tmpnam];
 | 
					  char _flog[L_tmpnam];
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
 | 
				
			|||||||
//                               load/store instructions
 | 
					//                               load/store instructions
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){
 | 
					value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
 | 
				
			||||||
  return insert(unmasked_load_inst::create(ptr, cache, is_volatile));
 | 
					  return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_store(value *ptr, value *val){
 | 
					value *builder::create_store(value *ptr, value *val){
 | 
				
			||||||
  return insert(unmasked_store_inst::create(ptr, val));
 | 
					  return insert(unmasked_store_inst::create(ptr, val));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){
 | 
					value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
 | 
				
			||||||
  return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile));
 | 
					  return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
 | 
					value *builder::create_masked_store(value *ptr, value *val, value *mask){
 | 
				
			||||||
@@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) {
 | 
				
			|||||||
  return insert(copy_from_shared_inst::create(arg));
 | 
					  return insert(copy_from_shared_inst::create(arg));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
 | 
					value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
 | 
				
			||||||
  return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
 | 
					  return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
value *builder::create_barrier(const std::string &name) {
 | 
					value *builder::create_barrier(const std::string &name) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
 | 
				
			|||||||
    return input;
 | 
					    return input;
 | 
				
			||||||
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
 | 
					  ir::type *src_sca_ty = src_ty->get_scalar_ty();
 | 
				
			||||||
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
 | 
					  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
 | 
				
			||||||
 | 
					     (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
 | 
				
			||||||
 | 
					    return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  // FP Truncation
 | 
					  // FP Truncation
 | 
				
			||||||
  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
 | 
					  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
 | 
				
			||||||
                     dst_sca_ty->is_floating_point_ty() &&
 | 
					                     dst_sca_ty->is_floating_point_ty() &&
 | 
				
			||||||
@@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
 | 
				
			|||||||
//                               Memory Operators
 | 
					//                               Memory Operators
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) {
 | 
					ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
 | 
				
			||||||
  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
 | 
					  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
 | 
				
			||||||
    throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
 | 
					    throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
 | 
				
			||||||
  if(ptr->get_type()->is_block_ty()){
 | 
					  if(ptr->get_type()->is_block_ty()){
 | 
				
			||||||
    if(mask){
 | 
					    if(mask)
 | 
				
			||||||
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
 | 
					      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
 | 
				
			||||||
    }
 | 
					    if(other)
 | 
				
			||||||
    if(other){
 | 
					 | 
				
			||||||
      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
 | 
					      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
 | 
				
			||||||
      other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  if(other)
 | 
				
			||||||
 | 
					    other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
 | 
				
			||||||
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
 | 
					  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
 | 
				
			||||||
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
 | 
					  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
 | 
				
			||||||
  // treat bool* as int8*
 | 
					  // treat bool* as int8*
 | 
				
			||||||
@@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
 | 
				
			|||||||
    else
 | 
					    else
 | 
				
			||||||
      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
 | 
					      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  // eviction policy
 | 
				
			||||||
 | 
					  load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
 | 
				
			||||||
 | 
					  if(!eviction_policy.empty()){
 | 
				
			||||||
 | 
					    if (eviction_policy == "evict_last")
 | 
				
			||||||
 | 
					        eviction = load_inst::EVICT_LAST;
 | 
				
			||||||
 | 
					    else if(eviction_policy == "evict_first")
 | 
				
			||||||
 | 
					        eviction = load_inst::EVICT_FIRST;
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					        throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (!mask && !other)
 | 
					  if (!mask && !other)
 | 
				
			||||||
    return builder->create_load(ptr, cache, is_volatile);
 | 
					    return builder->create_load(ptr, cache, eviction, is_volatile);
 | 
				
			||||||
  if (!mask)
 | 
					  if (!mask)
 | 
				
			||||||
    throw std::runtime_error("`other` cannot be provided without `mask`");
 | 
					    throw std::runtime_error("`other` cannot be provided without `mask`");
 | 
				
			||||||
  auto shape = ptr->get_type()->get_block_shapes();
 | 
					  auto shape = ptr->get_type()->get_block_shapes();
 | 
				
			||||||
@@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
 | 
				
			|||||||
    if(ptr->get_type()->is_block_ty())
 | 
					    if(ptr->get_type()->is_block_ty())
 | 
				
			||||||
      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
 | 
					      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return builder->create_masked_load(ptr, mask, other, cache, is_volatile);
 | 
					  return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
 | 
					ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
 | 
				
			|||||||
{ }
 | 
					{ }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// load_inst
 | 
					// load_inst
 | 
				
			||||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
 | 
					load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
 | 
				
			||||||
  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile)
 | 
					  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
 | 
				
			||||||
{ }
 | 
					{ }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// load
 | 
					// load
 | 
				
			||||||
@@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// unmasked_load
 | 
					// unmasked_load
 | 
				
			||||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
 | 
					unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
 | 
				
			||||||
  : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) {
 | 
					  : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
 | 
				
			||||||
  set_operand(0, ptr);
 | 
					  set_operand(0, ptr);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) {
 | 
					unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
 | 
				
			||||||
  return new unmasked_load_inst(ptr, cache, is_volatile, name, next);
 | 
					  return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// masked load
 | 
					// masked load
 | 
				
			||||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
 | 
					masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
 | 
				
			||||||
 | 
					                                   bool is_volatile,
 | 
				
			||||||
                                   const std::string &name, instruction *next)
 | 
					                                   const std::string &name, instruction *next)
 | 
				
			||||||
  : load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) {
 | 
					  : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
 | 
				
			||||||
  set_operand(0, ptr);
 | 
					  set_operand(0, ptr);
 | 
				
			||||||
  set_operand(1, mask);
 | 
					  set_operand(1, mask);
 | 
				
			||||||
  set_operand(2, false_value);
 | 
					  set_operand(2, false_value);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
 | 
					masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                                           load_inst::CACHE_MODIFIER cache, bool is_volatile,
 | 
					                                           load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
 | 
				
			||||||
 | 
					                                           bool is_volatile,
 | 
				
			||||||
                                           const std::string &name, instruction *next) {
 | 
					                                           const std::string &name, instruction *next) {
 | 
				
			||||||
  return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next);
 | 
					  return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// masked load async
 | 
					// masked load async
 | 
				
			||||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
 | 
					masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                                   load_inst::CACHE_MODIFIER cache,
 | 
					                                   load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
 | 
				
			||||||
                                   const std::string &name, instruction *next)
 | 
					                                   const std::string &name, instruction *next)
 | 
				
			||||||
  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) {
 | 
					  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
 | 
				
			||||||
  set_operand(0, ptr);
 | 
					  set_operand(0, ptr);
 | 
				
			||||||
  set_operand(1, mask);
 | 
					  set_operand(1, mask);
 | 
				
			||||||
  set_operand(2, false_value);
 | 
					  set_operand(2, false_value);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
 | 
					masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
 | 
				
			||||||
                                           load_inst::CACHE_MODIFIER cache,
 | 
					                                           load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
 | 
				
			||||||
                                           const std::string &name, instruction *next) {
 | 
					                                           const std::string &name, instruction *next) {
 | 
				
			||||||
  return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
 | 
					  return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// store
 | 
					// store
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -472,7 +472,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
 | 
				
			|||||||
  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
 | 
					  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
 | 
				
			||||||
  size_t cc = major*10 + minor;
 | 
					  size_t cc = major*10 + minor;
 | 
				
			||||||
  int version;
 | 
					  int version;
 | 
				
			||||||
  drv::dispatch::cuDriverGetVersion(&version);
 | 
					  std::string ptxas_path = drv::path_to_ptxas(version);
 | 
				
			||||||
  // Triton-IR -> NVPTX LLVM-IR
 | 
					  // Triton-IR -> NVPTX LLVM-IR
 | 
				
			||||||
  triton::codegen::nvidia_cu_target target(cc);
 | 
					  triton::codegen::nvidia_cu_target target(cc);
 | 
				
			||||||
  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
 | 
					  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
 | 
				
			||||||
@@ -485,7 +485,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
 | 
				
			|||||||
  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
 | 
					  std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
 | 
				
			||||||
  asm_map["ptx"] = py::cast(ptx);
 | 
					  asm_map["ptx"] = py::cast(ptx);
 | 
				
			||||||
  // PTX -> Binary
 | 
					  // PTX -> Binary
 | 
				
			||||||
  std::string cubin = drv::ptx_to_cubin(ptx, cc);
 | 
					  std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
 | 
				
			||||||
  if(!cubin.empty()){
 | 
					  if(!cubin.empty()){
 | 
				
			||||||
    py::bytes bytes(cubin);
 | 
					    py::bytes bytes(cubin);
 | 
				
			||||||
    asm_map["cubin"] = bytes;
 | 
					    asm_map["cubin"] = bytes;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@builtin
 | 
					@builtin
 | 
				
			||||||
def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None):
 | 
					def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
 | 
					    Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
 | 
				
			|||||||
    :param cache_modifier: changes cache option in nvidia ptx
 | 
					    :param cache_modifier: changes cache option in nvidia ptx
 | 
				
			||||||
    'type cache_modifier: str, optional
 | 
					    'type cache_modifier: str, optional
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder)
 | 
					    return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@builtin
 | 
					@builtin
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user