[CODEGEN] Add cache modifier to tl.load (#351)

* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
This commit is contained in:
daadaada
2021-10-18 13:14:04 +08:00
committed by GitHub
parent 90ded16c32
commit 858dec8372
16 changed files with 119 additions and 63 deletions

View File

@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(),
select->get_else_value_op());
select->get_else_value_op(),
if_value->get_cache_modifier());
select->replace_all_uses_with(new_load);
return true;
}