diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 7e580251c..6bd62ba81 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -711,11 +711,35 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { distributed_tile* vals = (distributed_tile*)tmap_.at(val); distributed_tile* msks = (distributed_tile*)tmap_.at(msk); - for_each(ptr, [&](indices_t idx){ + // vector size + int vector_size = 1; + int ld = ptrs->get_order()[0]; + unsigned alignment = alignment_->get(ptr, ld); + vector_size = gcd(ptrs->axis(ld).contiguous, alignment); + vector_size = std::min(vector_size, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1); + vector_size = 1; + + std::map packets; + for_each(val, [&](indices_t idx){ + unsigned linear = vals->get_linear_index(idx); + unsigned id = linear / vector_size; + Value *in_value = vals->get_value(idx); + if(linear % vector_size == 0) + packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); + packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); + }); + + for_each(val, [&](indices_t idx){ Value *rmw_ptr = ptrs->get_value(idx); - Value *rmw_val = vals->get_value(idx); Value *rmw_msk = msks->get_value(idx); + unsigned linear = vals->get_linear_index(idx); + unsigned id = linear / vector_size; + if(linear % vector_size != 0) + return; // num bytes + Value *rmw_val = packets[id]; + if(vector_size == 1) + rmw_val = builder_->CreateExtractElement(rmw_val, builder_->getInt32(0)); Type* ty = rmw_val->getType(); size_t nbits = ty->getScalarSizeInBits(); // extract pointer offset @@ -732,9 +756,10 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { // asm function type FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); // asm string + std::string suffix = vector_size == 2 ? "x2" : ""; std::string mod = nbits == 32 ? "" : ".noftz"; - std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;"; - std::string ty_id = nbits == 32 ? "f" : "h"; + std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;"; + std::string ty_id = nbits == 32 ? "f" : (vector_size == 1 ? "h" : "r"); std::string constraint = "b,=" + ty_id + ",l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);