[CODEGEN] Improvements and bugfixes (#463)
This commit is contained in:
@@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, is_volatile));
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile));
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
@@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) {
|
||||
return insert(copy_from_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
|
@@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
return input;
|
||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||
//
|
||||
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
|
||||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
|
||||
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
|
||||
}
|
||||
// FP Truncation
|
||||
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
||||
dst_sca_ty->is_floating_point_ty() &&
|
||||
@@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
// Memory Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) {
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
|
||||
if(ptr->get_type()->is_block_ty()){
|
||||
if(mask){
|
||||
if(mask)
|
||||
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
||||
}
|
||||
if(other){
|
||||
if(other)
|
||||
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
|
||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
}
|
||||
}
|
||||
if(other)
|
||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
||||
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
||||
// treat bool* as int8*
|
||||
@@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
||||
else
|
||||
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
|
||||
}
|
||||
// eviction policy
|
||||
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
|
||||
if(!eviction_policy.empty()){
|
||||
if (eviction_policy == "evict_last")
|
||||
eviction = load_inst::EVICT_LAST;
|
||||
else if(eviction_policy == "evict_first")
|
||||
eviction = load_inst::EVICT_FIRST;
|
||||
else
|
||||
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
|
||||
}
|
||||
|
||||
|
||||
if (!mask && !other)
|
||||
return builder->create_load(ptr, cache, is_volatile);
|
||||
return builder->create_load(ptr, cache, eviction, is_volatile);
|
||||
if (!mask)
|
||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||
auto shape = ptr->get_type()->get_block_shapes();
|
||||
@@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con
|
||||
if(ptr->get_type()->is_block_ty())
|
||||
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
|
||||
}
|
||||
return builder->create_masked_load(ptr, mask, other, cache, is_volatile);
|
||||
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
|
||||
}
|
||||
|
||||
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
|
||||
|
@@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile)
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
}
|
||||
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) {
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, is_volatile, name, next);
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache, bool is_volatile,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next);
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
|
||||
}
|
||||
|
||||
// masked load async
|
||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
|
Reference in New Issue
Block a user