[CODEGEN] Bugfix in prefetch pass (#118)

This commit is contained in:
Philippe Tillet
2021-05-20 22:34:38 -04:00
committed by Philippe Tillet
parent 3ab121dbdb
commit 38ab4e955a
2 changed files with 12 additions and 6 deletions

View File

@@ -903,9 +903,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
std::string suffix = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;";
std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
// call asm
@@ -922,9 +922,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
std::string mod = nbits == 32 ? "" : ".noftz";
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;";
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;";
std::string ty_id = nbits == 32 ? "r" : "h";
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true);
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true);
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();

View File

@@ -95,11 +95,17 @@ void prefetch::run(ir::module &mod) {
});
builder.set_insert_point(bb->get_first_non_phi());
for (ir::instruction *i : loads) {
auto& inst_list = bb->get_inst_list();
for (ir::instruction *i : loads){
auto it = std::find(inst_list.begin(), inst_list.end(), i);
// make sure we don't invalidate insert point
// in case instruction already at the top
if(it == builder.get_insert_point())
continue;
bb->erase(i);
builder.insert(i);
}
}
}
}
} // namespace triton::codegen::transform
} // namespace triton::codegen::transform