[CODEGEN] Improvements and bugfixes (#463)
This commit is contained in:
@@ -110,6 +110,15 @@ libLLVMBinaryFormat.a
|
||||
libLLVMAMDGPUInfo.a
|
||||
libLLVMSupport.a
|
||||
libLLVMDemangle.a
|
||||
libLLVMPasses.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMipo.a
|
||||
libLLVMObjCARCOpts.a
|
||||
libLLVMCoroutines.a
|
||||
libLLVMAnalysis.a
|
||||
)
|
||||
endif()
|
||||
include_directories("${LLVM_INCLUDE_DIRS}")
|
||||
@@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
if(WIN32)
|
||||
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
||||
else()
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z)
|
||||
endif()
|
||||
|
||||
|
||||
|
@@ -9,8 +9,9 @@ namespace triton{
|
||||
namespace driver{
|
||||
|
||||
void init_llvm();
|
||||
std::string path_to_ptxas(int& version);
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||
std::string ptx_to_cubin(const std::string& ptx, int cc);
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||
|
@@ -136,9 +136,9 @@ public:
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile);
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
@@ -163,7 +163,7 @@ public:
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
|
@@ -69,7 +69,8 @@ struct dispatch{
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder);
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
|
||||
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
@@ -408,11 +408,18 @@ public:
|
||||
CG,
|
||||
};
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
@@ -420,6 +427,11 @@ protected:
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
}
|
||||
EVICTION_POLICY eviction_;
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
@@ -435,11 +447,12 @@ private:
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next);
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache, bool is_volatile,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
@@ -450,7 +463,7 @@ public:
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
@@ -459,7 +472,8 @@ public:
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, bool is_volatile,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
@@ -470,8 +484,9 @@ public:
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next);
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
@@ -480,6 +495,7 @@ public:
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
|
@@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
|
||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
||||
#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
@@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||
// <> BF16
|
||||
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
|
||||
// FP32 -> BF16
|
||||
if(op_sca_ty->is_fp32_ty())
|
||||
// for(size_t i = 0; i < x_idxs.size(); i++)
|
||||
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
|
||||
if(op_sca_ty->is_fp32_ty()){
|
||||
for (indices_t idx: idxs_.at(x)) {
|
||||
Value *arg = vals_[x->get_operand(0)][idx];
|
||||
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// BF16 -> FP32
|
||||
if(ret_sca_ty->is_fp32_ty())
|
||||
if(ret_sca_ty->is_fp32_ty()){
|
||||
for(size_t i = 0; i < x_idxs.size(); i++)
|
||||
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
|
||||
return;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
std::ostringstream asm_oss;
|
||||
asm_oss << "@$" << n_words; // predicate
|
||||
asm_oss << " ld";
|
||||
// std::cout << x->get_is_volatile() << std::endl;
|
||||
if(x->get_is_volatile())
|
||||
asm_oss << ".volatile";
|
||||
asm_oss << ".global";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
|
@@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
|
||||
int nts = layout->nts(layout->get_order()[0]);
|
||||
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(nts*dtsize >= 4){
|
||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
|
||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
|
||||
copy_to_shared->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
}
|
||||
@@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
|
||||
if_value->get_mask_operand(),
|
||||
select->get_else_value_op(),
|
||||
if_value->get_cache_modifier(),
|
||||
if_value->get_eviction_policy(),
|
||||
if_value->get_is_volatile());
|
||||
select->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
|
@@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) {
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
@@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) {
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile());
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
}
|
||||
|
||||
// create new phis for induction variables
|
||||
@@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) {
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
|
||||
|
||||
// phi node
|
||||
@@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) {
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
@@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) {
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||
|
@@ -59,6 +59,13 @@
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
// end AMD stuff
|
||||
|
||||
extern "C"{
|
||||
int set_curterm(char* nterm){ return 0; }
|
||||
int del_curterm(char* nterm){ return 0; }
|
||||
int tigetnum(char *capname) { return 0; }
|
||||
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace driver{
|
||||
|
||||
@@ -77,6 +84,7 @@ void init_llvm() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
@@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string path_to_ptxas(int& version) {
|
||||
std::string ret;
|
||||
// search pathes for ptxas
|
||||
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
||||
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
||||
if(!triton_ptxas.empty())
|
||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||
// see what path for ptxas are valid
|
||||
std::vector<std::string> working_ptxas;
|
||||
for(std::string prefix: ptxas_prefixes){
|
||||
std::string ptxas = prefix + "ptxas";
|
||||
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
||||
if(works)
|
||||
working_ptxas.push_back(ptxas);
|
||||
}
|
||||
// error if no working ptxas was found
|
||||
if(working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
||||
" but a working version could not be found.");
|
||||
std::string ptxas = working_ptxas.front();
|
||||
// parse version
|
||||
std::regex version_regex("release (\\d+)\\.(\\d+)");
|
||||
std::smatch match;
|
||||
if(std::regex_search(ret, match, version_regex)){
|
||||
int major = std::stoi(match[1]);
|
||||
int minor = std::stoi(match[2]);
|
||||
version = major*1000 + minor*10;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("couldn't parse ptxas version: " + ret);
|
||||
return ptxas;
|
||||
}
|
||||
|
||||
|
||||
int vptx(int version){
|
||||
if(version >= 11040) return 74;
|
||||
if(version >= 11030) return 73;
|
||||
if(version >= 11020) return 72;
|
||||
if(version >= 11010) return 71;
|
||||
@@ -103,7 +146,7 @@ int vptx(int version){
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
int max_nvvm_ptx = 64;
|
||||
int max_nvvm_ptx = 74;
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||
@@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
std::string triple = "nvptx64-nvidia-cuda";
|
||||
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
||||
std::string layout = "";
|
||||
std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||
std::string features = "";
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||
init_llvm();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
@@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
||||
std::string version;
|
||||
// search pathes for ptxas
|
||||
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
||||
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
||||
if(!triton_ptxas.empty())
|
||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||
// see what path for ptxas are valid
|
||||
std::vector<std::string> working_ptxas;
|
||||
for(std::string prefix: ptxas_prefixes){
|
||||
std::string ptxas = prefix + "ptxas";
|
||||
bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
|
||||
if(works)
|
||||
working_ptxas.push_back(ptxas);
|
||||
}
|
||||
// error if no working ptxas was found
|
||||
if(working_ptxas.empty())
|
||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
||||
" but a working version could not be found.");
|
||||
std::string ptxas = working_ptxas.front();
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[L_tmpnam];
|
||||
char _flog[L_tmpnam];
|
||||
|
@@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, is_volatile));
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile));
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
@@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) {
|
||||
return insert(copy_from_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
|
@@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
return input;
|
||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||
//
|
||||
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
|
||||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
|
||||
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
|
||||
}
|
||||
// FP Truncation
|
||||
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
||||
dst_sca_ty->is_floating_point_ty() &&
|
||||
@@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
// Memory Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) {
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
|
||||
if(ptr->get_type()->is_block_ty()){
|
||||
if(mask){
|
||||
if(mask)
|
||||
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
||||
}
|
||||
if(other){
|
||||
if(other)
|
||||
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
|
||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
}
|
||||
}
|
||||
if(other)
|
||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
||||
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
||||
// treat bool* as int8*
|
||||
@@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
||||
else
|
||||
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
|
||||
}
|
||||
// eviction policy
|
||||
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
|
||||
if(!eviction_policy.empty()){
|
||||
if (eviction_policy == "evict_last")
|
||||
eviction = load_inst::EVICT_LAST;
|
||||
else if(eviction_policy == "evict_first")
|
||||
eviction = load_inst::EVICT_FIRST;
|
||||
else
|
||||
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
|
||||
}
|
||||
|
||||
|
||||
if (!mask && !other)
|
||||
return builder->create_load(ptr, cache, is_volatile);
|
||||
return builder->create_load(ptr, cache, eviction, is_volatile);
|
||||
if (!mask)
|
||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||
auto shape = ptr->get_type()->get_block_shapes();
|
||||
@@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
||||
if(ptr->get_type()->is_block_ty())
|
||||
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
|
||||
}
|
||||
return builder->create_masked_load(ptr, mask, other, cache, is_volatile);
|
||||
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
|
||||
}
|
||||
|
||||
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
|
||||
|
@@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile)
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
}
|
||||
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) {
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, is_volatile, name, next);
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next);
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load async
|
||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
|
@@ -472,7 +472,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||
size_t cc = major*10 + minor;
|
||||
int version;
|
||||
drv::dispatch::cuDriverGetVersion(&version);
|
||||
std::string ptxas_path = drv::path_to_ptxas(version);
|
||||
// Triton-IR -> NVPTX LLVM-IR
|
||||
triton::codegen::nvidia_cu_target target(cc);
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
||||
@@ -485,7 +485,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
||||
asm_map["ptx"] = py::cast(ptx);
|
||||
// PTX -> Binary
|
||||
std::string cubin = drv::ptx_to_cubin(ptx, cc);
|
||||
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
||||
if(!cubin.empty()){
|
||||
py::bytes bytes(cubin);
|
||||
asm_map["cubin"] = bytes;
|
||||
|
@@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None):
|
||||
def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
|
||||
"""
|
||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||
|
||||
@@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
|
||||
:param cache_modifier: changes cache option in nvidia ptx
|
||||
'type cache_modifier: str, optional
|
||||
"""
|
||||
return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder)
|
||||
return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
|
Reference in New Issue
Block a user