diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8bcbefdf8..577b6c9c7 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -767,10 +767,21 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { }); } else{ - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); Value *rmw_ptr = vmap_.at(add->get_operand(0)); Value *rmw_val = vmap_.at(add->get_operand(1)); + Value *rmw_msk = vmap_.at(add->get_operand(2)); + Type* ty = rmw_val->getType(); + size_t nbits = ty->getScalarSizeInBits(); + std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; + FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); + std::string mod = nbits == 32 ? "" : ".noftz"; + std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2], $3;"; + std::string ty_id = nbits == 32 ? "f" : "h"; + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true); + + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *tid = tgt_->get_local_id(module, *builder_, 0); Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); @@ -779,9 +790,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { tgt_->add_barrier(module, *builder_); builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val, - AtomicOrdering::Monotonic, - SyncScope::System); + builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val}); builder_->CreateBr(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 71a18027d..6f457d6c9 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -358,14 +358,19 @@ std::string function::preheader() { #define DECLARATION(TYPE, TM, TN) extern void atomic_add(TYPE, TM, TN)(TYPE*[TM, TN], TYPE[TM, TN], bool[TM, TN]) DECLARATION(float, 64, 64); +DECLARATION(float, 64, 128); +DECLARATION(float, 128, 64); +DECLARATION(float, 128, 128); +extern void atomic_add_half_1x1(half*, half, bool); + DECLARATION(half , 64, 64); DECLARATION(half , 64, 128); DECLARATION(half , 128, 64); DECLARATION(half , 128, 128); +extern void atomic_add_float_1x1(float*, float, bool); extern int atomic_cas(int*, int, int); extern int atomic_xchg(int*, int); -extern float f32_atomic_add(float*, float); extern int get_program_id(int); extern int get_num_programs(int); extern int select(bool, int, int);