diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 92674b7db..0f69f90cd 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -618,8 +618,25 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { tgt_->add_memfence(module, *builder_); } -void generator::visit_atomic_add_inst(ir::atomic_add_inst*) { - throw std::runtime_error("unsupported"); +void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *rmw_ptr = vmap_.at(add->get_operand(0)); + Value *rmw_val = vmap_.at(add->get_operand(1)); + Value *tid = tgt_->get_local_id(module, *builder_, 0); + Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); + BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); + BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); + tgt_->add_memfence(module, *builder_); + tgt_->add_barrier(module, *builder_); + builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + builder_->SetInsertPoint(tid_0_bb); + builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val, + AtomicOrdering::Monotonic, + SyncScope::System); + builder_->CreateBr(tid_0_done_bb); + builder_->SetInsertPoint(tid_0_done_bb); + tgt_->add_memfence(module, *builder_); } void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc index 2244ebccd..1384e2102 100644 --- a/lib/codegen/transform/disassociate.cc +++ b/lib/codegen/transform/disassociate.cc @@ -34,10 +34,13 @@ void disassociate::run(ir::module &mod) { std::map>> clone_info; ir::for_each_instruction(mod, [&](ir::instruction *i){ if(dynamic_cast(i)){ + ir::value* op = i->get_operand(0); + if(!dynamic_cast(op)) + return; + if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank()) + return; std::map> chains; std::set seen; - if(!dynamic_cast(i->get_operand(0))) - return; extract_retile_chain(i, chains, 0, seen); if(chains.size()) clone_info[i] = chains; diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 517553e97..de11ab646 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -277,6 +277,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { ir::value* val = ret_; return set_ret(bld_->create_atomic_exch(ptr, val)); } + if(name == "f32_atomic_add"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* ptr = ret_; + VisitExpr(funcCall->Args()->at(1)); + ir::value* val = ret_; + return set_ret(bld_->create_atomic_add(ptr, val)); + } if(name == "sqrtf"){ VisitExpr(funcCall->Args()->at(0)); ir::value* ret = ret_; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 657004a32..406dd89a4 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -399,6 +399,7 @@ std::string function::preheader() { extern int atomic_cas(int*, int, int); extern int atomic_xchg(int*, int); +extern float f32_atomic_add(float*, float); extern int get_program_id(int); extern int get_num_programs(int); extern float sqrtf(float);