[CODEGEN] Add cache modifier to tl.load (#351)

* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
This commit is contained in:
daadaada
2021-10-18 13:14:04 +08:00
committed by GitHub
parent 90ded16c32
commit 858dec8372
16 changed files with 119 additions and 63 deletions

View File

@@ -25,7 +25,7 @@ namespace codegen {
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) {
int cc, int num_warps, int num_stages, int& shared_static) {
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
@@ -46,7 +46,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
dce.run(ir);
peephole.run(ir);

View File

@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps, bool force_nc_cache)
unsigned num_warps)
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
}
@@ -629,10 +629,9 @@ void generator::visit_load_inst(ir::load_inst* x){
// -----
std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate
// if(force_nc_cache_)
asm_oss << " ld.global";
// else
// asm_oss << " ld.global.cg";
asm_oss << " ld.global";
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size

View File

@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(),
select->get_else_value_op());
select->get_else_value_op(),
if_value->get_cache_modifier());
select->replace_all_uses_with(new_load);
return true;
}

View File

@@ -111,6 +111,8 @@ struct pipeline_info_t {
};
void pipeline::run(ir::module &mod) {
if (num_stages_ <= 1)
return;
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
@@ -176,7 +178,7 @@ void pipeline::run(ir::module &mod) {
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
@@ -191,7 +193,7 @@ void pipeline::run(ir::module &mod) {
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
}
// create new phis for induction variables
@@ -220,7 +222,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
// phi node
@@ -255,7 +257,7 @@ void pipeline::run(ir::module &mod) {
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -266,7 +268,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);

View File

@@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
// load/store instructions
//===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr){
return insert(unmasked_load_inst::create(ptr));
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){
return insert(unmasked_load_inst::create(ptr, cache));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value){
return insert(masked_load_inst::create(ptr, mask, false_value));
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){
return insert(masked_load_inst::create(ptr, mask, false_value, cache));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
@@ -379,8 +379,8 @@ value *builder::create_copy_from_shared(value *arg) {
return insert(copy_from_shared_inst::create(arg));
}
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) {
return insert(masked_load_async_inst::create(ptr, mask, false_value));
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
}
value *builder::create_barrier(const std::string &name) {

View File

@@ -497,7 +497,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) {
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
@@ -517,8 +517,17 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
if (!cache_modifier.empty()) {
if (cache_modifier == ".ca")
cache = load_inst::CA;
else if (cache_modifier == ".cg")
cache = load_inst::CG;
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
if (!mask && !other)
return builder->create_load(ptr);
return builder->create_load(ptr, cache);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes();
@@ -527,7 +536,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other);
return builder->create_masked_load(ptr, mask, other, cache);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {

View File

@@ -433,8 +433,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache)
{ }
// load
@@ -447,41 +447,44 @@ type *load_inst::get_pointee_type(type *ty) {
}
// unmasked_load
unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) {
set_operand(0, ptr);
}
unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, name, next);
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, cache, name, next);
}
// masked load
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) {
return new masked_load_inst(ptr, mask, false_value, name, next);
return new masked_load_inst(ptr, mask, false_value, cache, name, next);
}
// masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, name, next);
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
}
// store