[CODEGEN] Progress on atom.add.f16x2

This commit is contained in:
Philippe Tillet
2020-11-12 16:48:04 -05:00
committed by Philippe Tillet
parent a77c925dfd
commit da287bb710

View File

@@ -711,11 +711,35 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
distributed_tile* vals = (distributed_tile*)tmap_.at(val); distributed_tile* vals = (distributed_tile*)tmap_.at(val);
distributed_tile* msks = (distributed_tile*)tmap_.at(msk); distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
for_each(ptr, [&](indices_t idx){ // vector size
int vector_size = 1;
int ld = ptrs->get_order()[0];
unsigned alignment = alignment_->get(ptr, ld);
vector_size = gcd(ptrs->axis(ld).contiguous, alignment);
vector_size = std::min(vector_size, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
vector_size = 1;
std::map<unsigned, Value*> packets;
for_each(val, [&](indices_t idx){
unsigned linear = vals->get_linear_index(idx);
unsigned id = linear / vector_size;
Value *in_value = vals->get_value(idx);
if(linear % vector_size == 0)
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size);
});
for_each(val, [&](indices_t idx){
Value *rmw_ptr = ptrs->get_value(idx); Value *rmw_ptr = ptrs->get_value(idx);
Value *rmw_val = vals->get_value(idx);
Value *rmw_msk = msks->get_value(idx); Value *rmw_msk = msks->get_value(idx);
unsigned linear = vals->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size != 0)
return;
// num bytes // num bytes
Value *rmw_val = packets[id];
if(vector_size == 1)
rmw_val = builder_->CreateExtractElement(rmw_val, builder_->getInt32(0));
Type* ty = rmw_val->getType(); Type* ty = rmw_val->getType();
size_t nbits = ty->getScalarSizeInBits(); size_t nbits = ty->getScalarSizeInBits();
// extract pointer offset // extract pointer offset
@@ -732,9 +756,10 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
// asm function type // asm function type
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
// asm string // asm string
std::string suffix = vector_size == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz"; std::string mod = nbits == 32 ? "" : ".noftz";
std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;"; std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
std::string ty_id = nbits == 32 ? "f" : "h"; std::string ty_id = nbits == 32 ? "f" : (vector_size == 1 ? "h" : "r");
std::string constraint = "b,=" + ty_id + ",l," + ty_id; std::string constraint = "b,=" + ty_id + ",l," + ty_id;
// create inline asm // create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);