[CODEGEN] Progress on atom.add.f16x2
This commit is contained in:
committed by
Philippe Tillet
parent
a77c925dfd
commit
da287bb710
@@ -711,11 +711,35 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
|||||||
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
||||||
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
||||||
|
|
||||||
for_each(ptr, [&](indices_t idx){
|
// vector size
|
||||||
|
int vector_size = 1;
|
||||||
|
int ld = ptrs->get_order()[0];
|
||||||
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
|
vector_size = gcd(ptrs->axis(ld).contiguous, alignment);
|
||||||
|
vector_size = std::min(vector_size, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
|
||||||
|
vector_size = 1;
|
||||||
|
|
||||||
|
std::map<unsigned, Value*> packets;
|
||||||
|
for_each(val, [&](indices_t idx){
|
||||||
|
unsigned linear = vals->get_linear_index(idx);
|
||||||
|
unsigned id = linear / vector_size;
|
||||||
|
Value *in_value = vals->get_value(idx);
|
||||||
|
if(linear % vector_size == 0)
|
||||||
|
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
|
||||||
|
packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size);
|
||||||
|
});
|
||||||
|
|
||||||
|
for_each(val, [&](indices_t idx){
|
||||||
Value *rmw_ptr = ptrs->get_value(idx);
|
Value *rmw_ptr = ptrs->get_value(idx);
|
||||||
Value *rmw_val = vals->get_value(idx);
|
|
||||||
Value *rmw_msk = msks->get_value(idx);
|
Value *rmw_msk = msks->get_value(idx);
|
||||||
|
unsigned linear = vals->get_linear_index(idx);
|
||||||
|
unsigned id = linear / vector_size;
|
||||||
|
if(linear % vector_size != 0)
|
||||||
|
return;
|
||||||
// num bytes
|
// num bytes
|
||||||
|
Value *rmw_val = packets[id];
|
||||||
|
if(vector_size == 1)
|
||||||
|
rmw_val = builder_->CreateExtractElement(rmw_val, builder_->getInt32(0));
|
||||||
Type* ty = rmw_val->getType();
|
Type* ty = rmw_val->getType();
|
||||||
size_t nbits = ty->getScalarSizeInBits();
|
size_t nbits = ty->getScalarSizeInBits();
|
||||||
// extract pointer offset
|
// extract pointer offset
|
||||||
@@ -732,9 +756,10 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
|||||||
// asm function type
|
// asm function type
|
||||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||||
// asm string
|
// asm string
|
||||||
|
std::string suffix = vector_size == 2 ? "x2" : "";
|
||||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||||
std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;";
|
std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
|
||||||
std::string ty_id = nbits == 32 ? "f" : "h";
|
std::string ty_id = nbits == 32 ? "f" : (vector_size == 1 ? "h" : "r");
|
||||||
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
|
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
|
||||||
// create inline asm
|
// create inline asm
|
||||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||||
|
Reference in New Issue
Block a user