diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 63f72f9ff..6507080f9 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -903,9 +903,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { std::string suffix = vec == 2 ? "x2" : ""; std::string mod = nbits == 32 ? "" : ".noftz"; std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u"; - std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;"; + std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;"; std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r"); - std::string constraint = "b,=" + ty_id + ",l," + ty_id; + std::string constraint = "=" + ty_id + ",b,l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); // call asm @@ -922,9 +922,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); std::string mod = nbits == 32 ? "" : ".noftz"; std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u"; - std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;"; + std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;"; std::string ty_id = nbits == 32 ? "r" : "h"; - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true); + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true); BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc index 876ee3e43..96bcf57d7 100644 --- a/lib/codegen/transform/prefetch.cc +++ b/lib/codegen/transform/prefetch.cc @@ -95,11 +95,17 @@ void prefetch::run(ir::module &mod) { }); builder.set_insert_point(bb->get_first_non_phi()); - for (ir::instruction *i : loads) { + auto& inst_list = bb->get_inst_list(); + for (ir::instruction *i : loads){ + auto it = std::find(inst_list.begin(), inst_list.end(), i); + // make sure we don't invalidate insert point + // in case instruction already at the top + if(it == builder.get_insert_point()) + continue; bb->erase(i); builder.insert(i); } } } } -} // namespace triton::codegen::transform \ No newline at end of file +} // namespace triton::codegen::transform