[CODEGEN] Improvements and bugfixes (#463)
This commit is contained in:
@@ -110,6 +110,15 @@ libLLVMBinaryFormat.a
|
|||||||
libLLVMAMDGPUInfo.a
|
libLLVMAMDGPUInfo.a
|
||||||
libLLVMSupport.a
|
libLLVMSupport.a
|
||||||
libLLVMDemangle.a
|
libLLVMDemangle.a
|
||||||
|
libLLVMPasses.a
|
||||||
|
libLLVMAnalysis.a
|
||||||
|
libLLVMTransformUtils.a
|
||||||
|
libLLVMScalarOpts.a
|
||||||
|
libLLVMTransformUtils.a
|
||||||
|
libLLVMipo.a
|
||||||
|
libLLVMObjCARCOpts.a
|
||||||
|
libLLVMCoroutines.a
|
||||||
|
libLLVMAnalysis.a
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
include_directories("${LLVM_INCLUDE_DIRS}")
|
include_directories("${LLVM_INCLUDE_DIRS}")
|
||||||
@@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
|||||||
if(WIN32)
|
if(WIN32)
|
||||||
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
||||||
else()
|
else()
|
||||||
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
|
target_link_libraries(triton ${LLVM_LIBRARIES} z)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
@@ -9,8 +9,9 @@ namespace triton{
|
|||||||
namespace driver{
|
namespace driver{
|
||||||
|
|
||||||
void init_llvm();
|
void init_llvm();
|
||||||
|
std::string path_to_ptxas(int& version);
|
||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||||
std::string ptx_to_cubin(const std::string& ptx, int cc);
|
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
||||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||||
|
@@ -136,9 +136,9 @@ public:
|
|||||||
value *create_xor(value *lhs, value *rhs);
|
value *create_xor(value *lhs, value *rhs);
|
||||||
value *create_or(value *lhs, value *rhs);
|
value *create_or(value *lhs, value *rhs);
|
||||||
// Input/Output
|
// Input/Output
|
||||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile);
|
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||||
value *create_store(value *ptr, value *val);
|
value *create_store(value *ptr, value *val);
|
||||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile);
|
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||||
// Block instruction
|
// Block instruction
|
||||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||||
@@ -163,7 +163,7 @@ public:
|
|||||||
// These have no place in the IR, and hopefully they can be removed at some point
|
// These have no place in the IR, and hopefully they can be removed at some point
|
||||||
value *create_umulhi(value* lhs, value* rhs);
|
value *create_umulhi(value* lhs, value* rhs);
|
||||||
value *create_copy_to_shared(value *arg);
|
value *create_copy_to_shared(value *arg);
|
||||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||||
value *create_copy_from_shared(value *arg);
|
value *create_copy_from_shared(value *arg);
|
||||||
value *create_barrier(const std::string &name = "");
|
value *create_barrier(const std::string &name = "");
|
||||||
value *create_async_wait(int N);
|
value *create_async_wait(int N);
|
||||||
|
@@ -69,7 +69,8 @@ struct dispatch{
|
|||||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||||
|
|
||||||
// memory operators
|
// memory operators
|
||||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder);
|
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
|
||||||
|
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
|
||||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
@@ -408,11 +408,18 @@ public:
|
|||||||
CG,
|
CG,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum EVICTION_POLICY : uint32_t {
|
||||||
|
NORMAL=0,
|
||||||
|
EVICT_FIRST,
|
||||||
|
EVICT_LAST,
|
||||||
|
};
|
||||||
|
|
||||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||||
|
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||||
bool get_is_volatile() const { return is_volatile_; }
|
bool get_is_volatile() const { return is_volatile_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
|
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||||
bool is_volatile,
|
bool is_volatile,
|
||||||
const std::string &name = "", instruction *next = nullptr);
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
std::string get_cache_modifier_repr() const {
|
std::string get_cache_modifier_repr() const {
|
||||||
@@ -420,6 +427,11 @@ protected:
|
|||||||
if (cache_ == CG) return ".cg";
|
if (cache_ == CG) return ".cg";
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
std::string get_eviction_policy_repr() const {
|
||||||
|
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||||
|
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||||
|
}
|
||||||
|
EVICTION_POLICY eviction_;
|
||||||
CACHE_MODIFIER cache_;
|
CACHE_MODIFIER cache_;
|
||||||
|
|
||||||
std::string get_volatile_repr() {
|
std::string get_volatile_repr() {
|
||||||
@@ -435,11 +447,12 @@ private:
|
|||||||
class unmasked_load_inst: public load_inst {
|
class unmasked_load_inst: public load_inst {
|
||||||
private:
|
private:
|
||||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next);
|
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static unmasked_load_inst* create(value *ptr,
|
static unmasked_load_inst* create(value *ptr,
|
||||||
CACHE_MODIFIER cache, bool is_volatile,
|
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||||
|
bool is_volatile,
|
||||||
const std::string &name = "",
|
const std::string &name = "",
|
||||||
instruction *next = nullptr);
|
instruction *next = nullptr);
|
||||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||||
@@ -450,7 +463,7 @@ public:
|
|||||||
class masked_load_inst: public load_inst {
|
class masked_load_inst: public load_inst {
|
||||||
private:
|
private:
|
||||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||||
const std::string &name, instruction *next);
|
const std::string &name, instruction *next);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -459,7 +472,8 @@ public:
|
|||||||
value *get_false_value_operand() { return get_operand(2); }
|
value *get_false_value_operand() { return get_operand(2); }
|
||||||
// factory method
|
// factory method
|
||||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||||
CACHE_MODIFIER cache, bool is_volatile,
|
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||||
|
bool is_volatile,
|
||||||
const std::string &name = "",
|
const std::string &name = "",
|
||||||
instruction *next = nullptr);
|
instruction *next = nullptr);
|
||||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||||
@@ -470,8 +484,9 @@ public:
|
|||||||
class masked_load_async_inst: public load_inst {
|
class masked_load_async_inst: public load_inst {
|
||||||
private:
|
private:
|
||||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||||
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||||
const std::string &name, instruction *next);
|
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||||
|
const std::string &name, instruction *next);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// accessors
|
// accessors
|
||||||
@@ -480,6 +495,7 @@ public:
|
|||||||
// factory method
|
// factory method
|
||||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||||
load_inst::CACHE_MODIFIER cache,
|
load_inst::CACHE_MODIFIER cache,
|
||||||
|
EVICTION_POLICY eviction,
|
||||||
const std::string &name = "",
|
const std::string &name = "",
|
||||||
instruction *next = nullptr);
|
instruction *next = nullptr);
|
||||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||||
|
@@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
|||||||
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
|
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
|
||||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
|
||||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||||
@@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
|||||||
// <> BF16
|
// <> BF16
|
||||||
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
|
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
|
||||||
// FP32 -> BF16
|
// FP32 -> BF16
|
||||||
if(op_sca_ty->is_fp32_ty())
|
if(op_sca_ty->is_fp32_ty()){
|
||||||
// for(size_t i = 0; i < x_idxs.size(); i++)
|
|
||||||
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
|
|
||||||
for (indices_t idx: idxs_.at(x)) {
|
for (indices_t idx: idxs_.at(x)) {
|
||||||
Value *arg = vals_[x->get_operand(0)][idx];
|
Value *arg = vals_[x->get_operand(0)][idx];
|
||||||
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
|
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
|
||||||
}
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
// BF16 -> FP32
|
// BF16 -> FP32
|
||||||
if(ret_sca_ty->is_fp32_ty())
|
if(ret_sca_ty->is_fp32_ty()){
|
||||||
for(size_t i = 0; i < x_idxs.size(); i++)
|
for(size_t i = 0; i < x_idxs.size(); i++)
|
||||||
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
|
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){
|
|||||||
std::ostringstream asm_oss;
|
std::ostringstream asm_oss;
|
||||||
asm_oss << "@$" << n_words; // predicate
|
asm_oss << "@$" << n_words; // predicate
|
||||||
asm_oss << " ld";
|
asm_oss << " ld";
|
||||||
// std::cout << x->get_is_volatile() << std::endl;
|
|
||||||
if(x->get_is_volatile())
|
if(x->get_is_volatile())
|
||||||
asm_oss << ".volatile";
|
asm_oss << ".volatile";
|
||||||
asm_oss << ".global";
|
asm_oss << ".global";
|
||||||
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
||||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||||
|
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
|
||||||
|
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
|
||||||
if(n_words > 1)
|
if(n_words > 1)
|
||||||
asm_oss << ".v" << n_words; // vector width
|
asm_oss << ".v" << n_words; // vector width
|
||||||
asm_oss << ".b" << width; // word size
|
asm_oss << ".b" << width; // word size
|
||||||
|
@@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
|
|||||||
int nts = layout->nts(layout->get_order()[0]);
|
int nts = layout->nts(layout->get_order()[0]);
|
||||||
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||||
if(nts*dtsize >= 4){
|
if(nts*dtsize >= 4){
|
||||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
|
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
|
||||||
copy_to_shared->replace_all_uses_with(new_load);
|
copy_to_shared->replace_all_uses_with(new_load);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
|
|||||||
if_value->get_mask_operand(),
|
if_value->get_mask_operand(),
|
||||||
select->get_else_value_op(),
|
select->get_else_value_op(),
|
||||||
if_value->get_cache_modifier(),
|
if_value->get_cache_modifier(),
|
||||||
|
if_value->get_eviction_policy(),
|
||||||
if_value->get_is_volatile());
|
if_value->get_is_volatile());
|
||||||
select->replace_all_uses_with(new_load);
|
select->replace_all_uses_with(new_load);
|
||||||
return true;
|
return true;
|
||||||
|
@@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
false_value = remat_false_value;
|
false_value = remat_false_value;
|
||||||
} else
|
} else
|
||||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile());
|
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||||
|
|
||||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||||
// mask is the loop condition of the previous iteration
|
// mask is the loop condition of the previous iteration
|
||||||
@@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||||
false_value = remat_false_value;
|
false_value = remat_false_value;
|
||||||
}
|
}
|
||||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile());
|
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new phis for induction variables
|
// create new phis for induction variables
|
||||||
@@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
next_mask = builder.create_and(next_mask, remat_mask);
|
next_mask = builder.create_and(next_mask, remat_mask);
|
||||||
false_value = remat_false_value;
|
false_value = remat_false_value;
|
||||||
}
|
}
|
||||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||||
|
|
||||||
|
|
||||||
// phi node
|
// phi node
|
||||||
@@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||||
// pre-fetch next iteration
|
// pre-fetch next iteration
|
||||||
builder.set_insert_point(block->get_inst_list().back());
|
builder.set_insert_point(block->get_inst_list().back());
|
||||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||||
@@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
next_mask = builder.create_and(next_mask, remat_mask);
|
next_mask = builder.create_and(next_mask, remat_mask);
|
||||||
false_value = remat_false_value;
|
false_value = remat_false_value;
|
||||||
}
|
}
|
||||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile());
|
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||||
// phi node
|
// phi node
|
||||||
builder.set_insert_point(block->get_first_non_phi());
|
builder.set_insert_point(block->get_first_non_phi());
|
||||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||||
|
@@ -59,6 +59,13 @@
|
|||||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||||
// end AMD stuff
|
// end AMD stuff
|
||||||
|
|
||||||
|
extern "C"{
|
||||||
|
int set_curterm(char* nterm){ return 0; }
|
||||||
|
int del_curterm(char* nterm){ return 0; }
|
||||||
|
int tigetnum(char *capname) { return 0; }
|
||||||
|
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
||||||
|
}
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace driver{
|
namespace driver{
|
||||||
|
|
||||||
@@ -77,6 +84,7 @@ void init_llvm() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// CUDA //
|
// CUDA //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
@@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string path_to_ptxas(int& version) {
|
||||||
|
std::string ret;
|
||||||
|
// search pathes for ptxas
|
||||||
|
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
||||||
|
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
||||||
|
if(!triton_ptxas.empty())
|
||||||
|
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||||
|
// see what path for ptxas are valid
|
||||||
|
std::vector<std::string> working_ptxas;
|
||||||
|
for(std::string prefix: ptxas_prefixes){
|
||||||
|
std::string ptxas = prefix + "ptxas";
|
||||||
|
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
||||||
|
if(works)
|
||||||
|
working_ptxas.push_back(ptxas);
|
||||||
|
}
|
||||||
|
// error if no working ptxas was found
|
||||||
|
if(working_ptxas.empty())
|
||||||
|
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
||||||
|
" but a working version could not be found.");
|
||||||
|
std::string ptxas = working_ptxas.front();
|
||||||
|
// parse version
|
||||||
|
std::regex version_regex("release (\\d+)\\.(\\d+)");
|
||||||
|
std::smatch match;
|
||||||
|
if(std::regex_search(ret, match, version_regex)){
|
||||||
|
int major = std::stoi(match[1]);
|
||||||
|
int minor = std::stoi(match[2]);
|
||||||
|
version = major*1000 + minor*10;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
throw std::runtime_error("couldn't parse ptxas version: " + ret);
|
||||||
|
return ptxas;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int vptx(int version){
|
int vptx(int version){
|
||||||
|
if(version >= 11040) return 74;
|
||||||
if(version >= 11030) return 73;
|
if(version >= 11030) return 73;
|
||||||
if(version >= 11020) return 72;
|
if(version >= 11020) return 72;
|
||||||
if(version >= 11010) return 71;
|
if(version >= 11010) return 71;
|
||||||
@@ -103,7 +146,7 @@ int vptx(int version){
|
|||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||||
// LLVM version in use may not officially support target hardware
|
// LLVM version in use may not officially support target hardware
|
||||||
int max_nvvm_cc = 75;
|
int max_nvvm_cc = 75;
|
||||||
int max_nvvm_ptx = 64;
|
int max_nvvm_ptx = 74;
|
||||||
// options
|
// options
|
||||||
auto options = llvm::cl::getRegisteredOptions();
|
auto options = llvm::cl::getRegisteredOptions();
|
||||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||||
@@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
std::string triple = "nvptx64-nvidia-cuda";
|
std::string triple = "nvptx64-nvidia-cuda";
|
||||||
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
|
||||||
std::string layout = "";
|
std::string layout = "";
|
||||||
std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
std::string features = "";
|
||||||
|
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
@@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||||
std::string version;
|
|
||||||
// search pathes for ptxas
|
|
||||||
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
|
|
||||||
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
|
|
||||||
if(!triton_ptxas.empty())
|
|
||||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
|
||||||
// see what path for ptxas are valid
|
|
||||||
std::vector<std::string> working_ptxas;
|
|
||||||
for(std::string prefix: ptxas_prefixes){
|
|
||||||
std::string ptxas = prefix + "ptxas";
|
|
||||||
bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
|
|
||||||
if(works)
|
|
||||||
working_ptxas.push_back(ptxas);
|
|
||||||
}
|
|
||||||
// error if no working ptxas was found
|
|
||||||
if(working_ptxas.empty())
|
|
||||||
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
|
|
||||||
" but a working version could not be found.");
|
|
||||||
std::string ptxas = working_ptxas.front();
|
|
||||||
// compile ptx with ptxas
|
// compile ptx with ptxas
|
||||||
char _fsrc[L_tmpnam];
|
char _fsrc[L_tmpnam];
|
||||||
char _flog[L_tmpnam];
|
char _flog[L_tmpnam];
|
||||||
|
@@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
|||||||
// load/store instructions
|
// load/store instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||||
return insert(unmasked_load_inst::create(ptr, cache, is_volatile));
|
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_store(value *ptr, value *val){
|
value *builder::create_store(value *ptr, value *val){
|
||||||
return insert(unmasked_store_inst::create(ptr, val));
|
return insert(unmasked_store_inst::create(ptr, val));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile));
|
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||||
@@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) {
|
|||||||
return insert(copy_from_shared_inst::create(arg));
|
return insert(copy_from_shared_inst::create(arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
|
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
|
||||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
|
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_barrier(const std::string &name) {
|
value *builder::create_barrier(const std::string &name) {
|
||||||
|
@@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
|||||||
return input;
|
return input;
|
||||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||||
|
//
|
||||||
|
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
|
||||||
|
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
|
||||||
|
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
|
||||||
|
}
|
||||||
// FP Truncation
|
// FP Truncation
|
||||||
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
||||||
dst_sca_ty->is_floating_point_ty() &&
|
dst_sca_ty->is_floating_point_ty() &&
|
||||||
@@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
|||||||
// Memory Operators
|
// Memory Operators
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) {
|
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
|
||||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||||
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
|
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
|
||||||
if(ptr->get_type()->is_block_ty()){
|
if(ptr->get_type()->is_block_ty()){
|
||||||
if(mask){
|
if(mask)
|
||||||
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
||||||
}
|
if(other)
|
||||||
if(other){
|
|
||||||
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
|
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
|
||||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if(other)
|
||||||
|
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||||
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
||||||
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
||||||
// treat bool* as int8*
|
// treat bool* as int8*
|
||||||
@@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
|||||||
else
|
else
|
||||||
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
|
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
|
||||||
}
|
}
|
||||||
|
// eviction policy
|
||||||
|
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
|
||||||
|
if(!eviction_policy.empty()){
|
||||||
|
if (eviction_policy == "evict_last")
|
||||||
|
eviction = load_inst::EVICT_LAST;
|
||||||
|
else if(eviction_policy == "evict_first")
|
||||||
|
eviction = load_inst::EVICT_FIRST;
|
||||||
|
else
|
||||||
|
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (!mask && !other)
|
if (!mask && !other)
|
||||||
return builder->create_load(ptr, cache, is_volatile);
|
return builder->create_load(ptr, cache, eviction, is_volatile);
|
||||||
if (!mask)
|
if (!mask)
|
||||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||||
auto shape = ptr->get_type()->get_block_shapes();
|
auto shape = ptr->get_type()->get_block_shapes();
|
||||||
@@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
|||||||
if(ptr->get_type()->is_block_ty())
|
if(ptr->get_type()->is_block_ty())
|
||||||
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
|
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
|
||||||
}
|
}
|
||||||
return builder->create_masked_load(ptr, mask, other, cache, is_volatile);
|
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
|
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
|
||||||
|
@@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
|
|||||||
{ }
|
{ }
|
||||||
|
|
||||||
// load_inst
|
// load_inst
|
||||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile)
|
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
// load
|
// load
|
||||||
@@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unmasked_load
|
// unmasked_load
|
||||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) {
|
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
|
||||||
set_operand(0, ptr);
|
set_operand(0, ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) {
|
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
|
||||||
return new unmasked_load_inst(ptr, cache, is_volatile, name, next);
|
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
// masked load
|
// masked load
|
||||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||||
|
bool is_volatile,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) {
|
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
|
||||||
set_operand(0, ptr);
|
set_operand(0, ptr);
|
||||||
set_operand(1, mask);
|
set_operand(1, mask);
|
||||||
set_operand(2, false_value);
|
set_operand(2, false_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||||
load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||||
|
bool is_volatile,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next);
|
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
// masked load async
|
// masked load async
|
||||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||||
load_inst::CACHE_MODIFIER cache,
|
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) {
|
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
|
||||||
set_operand(0, ptr);
|
set_operand(0, ptr);
|
||||||
set_operand(1, mask);
|
set_operand(1, mask);
|
||||||
set_operand(2, false_value);
|
set_operand(2, false_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||||
load_inst::CACHE_MODIFIER cache,
|
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
|
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
// store
|
// store
|
||||||
|
@@ -472,7 +472,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||||
size_t cc = major*10 + minor;
|
size_t cc = major*10 + minor;
|
||||||
int version;
|
int version;
|
||||||
drv::dispatch::cuDriverGetVersion(&version);
|
std::string ptxas_path = drv::path_to_ptxas(version);
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
// Triton-IR -> NVPTX LLVM-IR
|
||||||
triton::codegen::nvidia_cu_target target(cc);
|
triton::codegen::nvidia_cu_target target(cc);
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
||||||
@@ -485,7 +485,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
||||||
asm_map["ptx"] = py::cast(ptx);
|
asm_map["ptx"] = py::cast(ptx);
|
||||||
// PTX -> Binary
|
// PTX -> Binary
|
||||||
std::string cubin = drv::ptx_to_cubin(ptx, cc);
|
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
||||||
if(!cubin.empty()){
|
if(!cubin.empty()){
|
||||||
py::bytes bytes(cubin);
|
py::bytes bytes(cubin);
|
||||||
asm_map["cubin"] = bytes;
|
asm_map["cubin"] = bytes;
|
||||||
|
@@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None):
|
def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None):
|
||||||
"""
|
"""
|
||||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||||
|
|
||||||
@@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
|
|||||||
:param cache_modifier: changes cache option in nvidia ptx
|
:param cache_modifier: changes cache option in nvidia ptx
|
||||||
'type cache_modifier: str, optional
|
'type cache_modifier: str, optional
|
||||||
"""
|
"""
|
||||||
return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder)
|
return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
|
Reference in New Issue
Block a user