[CODEGEN] Improvements and bugfixes (#463)

This commit is contained in:
Philippe Tillet
2022-02-24 14:56:24 -08:00
committed by GitHub
parent a9dfdcaaa9
commit 98ed7db8c1
14 changed files with 154 additions and 81 deletions

View File

@@ -110,6 +110,15 @@ libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
include_directories("${LLVM_INCLUDE_DIRS}")
@@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif()

View File

@@ -9,8 +9,9 @@ namespace triton{
namespace driver{
void init_llvm();
std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, int cc);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);

View File

@@ -136,9 +136,9 @@ public:
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile);
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
@@ -163,7 +163,7 @@ public:
// These have no place in the IR, and hopefully they can be removed at some point
value *create_umulhi(value* lhs, value* rhs);
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);

View File

@@ -69,7 +69,8 @@ struct dispatch{
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder);
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);

View File

@@ -408,11 +408,18 @@ public:
CG,
};
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
bool get_is_volatile() const { return is_volatile_; }
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "", instruction *next = nullptr);
std::string get_cache_modifier_repr() const {
@@ -420,6 +427,11 @@ protected:
if (cache_ == CG) return ".cg";
return "";
}
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
}
EVICTION_POLICY eviction_;
CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
@@ -435,11 +447,12 @@ private:
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next);
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache, bool is_volatile,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -450,7 +463,7 @@ public:
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
const std::string &name, instruction *next);
public:
@@ -459,7 +472,8 @@ public:
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache, bool is_volatile,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
@@ -470,8 +484,9 @@ public:
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next);
masked_load_async_inst(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
@@ -480,6 +495,7 @@ public:
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)

View File

@@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
@@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
// <> BF16
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty())
// for(size_t i = 0; i < x_idxs.size(); i++)
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
if(op_sca_ty->is_fp32_ty()){
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
return;
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty())
if(ret_sca_ty->is_fp32_ty()){
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return;
return;
}
}
@@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){
std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate
asm_oss << " ld";
// std::cout << x->get_is_volatile() << std::endl;
if(x->get_is_volatile())
asm_oss << ".volatile";
asm_oss << ".global";
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size

View File

@@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
@@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
if_value->get_mask_operand(),
select->get_else_value_op(),
if_value->get_cache_modifier(),
if_value->get_eviction_policy(),
if_value->get_is_volatile());
select->replace_all_uses_with(new_load);
return true;

View File

@@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) {
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
@@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) {
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile());
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
}
// create new phis for induction variables
@@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
@@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) {
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);

View File

@@ -59,6 +59,13 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
// end AMD stuff
extern "C"{
int set_curterm(char* nterm){ return 0; }
int del_curterm(char* nterm){ return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace triton{
namespace driver{
@@ -77,6 +84,7 @@ void init_llvm() {
}
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
@@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
return true;
}
std::string path_to_ptxas(int& version) {
std::string ret;
// search pathes for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if(works)
working_ptxas.push_back(ptxas);
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
if(std::regex_search(ret, match, version_regex)){
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major*1000 + minor*10;
}
else
throw std::runtime_error("couldn't parse ptxas version: " + ret);
return ptxas;
}
int vptx(int version){
if(version >= 11040) return 74;
if(version >= 11030) return 73;
if(version >= 11020) return 72;
if(version >= 11010) return 71;
@@ -103,7 +146,7 @@ int vptx(int version){
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 64;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
@@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
@@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
return result;
}
std::string ptx_to_cubin(const std::string& ptx, int cc) {
std::string version;
// search pathes for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
if(works)
working_ptxas.push_back(ptxas);
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];

View File

@@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
// load/store instructions
//===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){
return insert(unmasked_load_inst::create(ptr, cache, is_volatile));
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile));
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
@@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) {
return insert(copy_from_shared_inst::create(arg));
}
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
}
value *builder::create_barrier(const std::string &name) {

View File

@@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
//
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
}
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
@@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) {
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(other){
if(other)
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
}
}
if(other)
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
@@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
// eviction policy
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
if(!eviction_policy.empty()){
if (eviction_policy == "evict_last")
eviction = load_inst::EVICT_LAST;
else if(eviction_policy == "evict_first")
eviction = load_inst::EVICT_FIRST;
else
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
}
if (!mask && !other)
return builder->create_load(ptr, cache, is_volatile);
return builder->create_load(ptr, cache, eviction, is_volatile);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes();
@@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other, cache, is_volatile);
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {

View File

@@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile)
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
{ }
// load
@@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) {
}
// unmasked_load
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) {
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
set_operand(0, ptr);
}
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, cache, is_volatile, name, next);
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
}
// masked load
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) {
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache, bool is_volatile,
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name, instruction *next) {
return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next);
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
}
// masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) {
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
}
// store

View File

@@ -472,7 +472,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
int version;
drv::dispatch::cuDriverGetVersion(&version);
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
@@ -485,7 +485,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
asm_map["ptx"] = py::cast(ptx);
// PTX -> Binary
std::string cubin = drv::ptx_to_cubin(ptx, cc);
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
if(!cubin.empty()){
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;

View File

@@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
@builtin
def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None):
def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
@@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
:param cache_modifier: changes cache option in nvidia ptx
'type cache_modifier: str, optional
"""
return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder)
return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
@builtin