[CODEGEN] Small bugfix in atomic-add (#114)

This commit is contained in:
Philippe Tillet
2021-05-20 01:12:30 -04:00
committed by Philippe Tillet
parent f81012a8cf
commit b5dcac484d

View File

@@ -902,7 +902,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
 // asm string
 std::string suffix = vec == 2 ? "x2" : "";
 std::string mod = nbits == 32 ? "" : ".noftz";
-std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
+std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
 std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
 std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
 std::string constraint = "b,=" + ty_id + ",l," + ty_id;
@@ -921,7 +921,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
 std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
 FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
 std::string mod = nbits == 32 ? "" : ".noftz";
-std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
+std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
 std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;";
 std::string ty_id = nbits == 32 ? "r" : "h";
 InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true);