[CODEGEN] Bugfix in Disassociate pass; Added fp32 atomic_add support
This commit is contained in:
committed by
Philippe Tillet
parent
bb2d98ce4b
commit
e7461a862b
@@ -618,8 +618,25 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
}
|
||||
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
|
||||
throw std::runtime_error("unsupported");
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *rmw_ptr = vmap_.at(add->get_operand(0));
|
||||
Value *rmw_val = vmap_.at(add->get_operand(1));
|
||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||
Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
tgt_->add_barrier(module, *builder_);
|
||||
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_bb);
|
||||
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
|
||||
AtomicOrdering::Monotonic,
|
||||
SyncScope::System);
|
||||
builder_->CreateBr(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
}
|
||||
|
||||
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
|
@@ -34,10 +34,13 @@ void disassociate::run(ir::module &mod) {
|
||||
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
if(dynamic_cast<ir::reshape_inst*>(i)){
|
||||
ir::value* op = i->get_operand(0);
|
||||
if(!dynamic_cast<ir::user*>(op))
|
||||
return;
|
||||
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
|
||||
return;
|
||||
std::map<int, std::set<ir::user*>> chains;
|
||||
std::set<ir::value*> seen;
|
||||
if(!dynamic_cast<ir::user*>(i->get_operand(0)))
|
||||
return;
|
||||
extract_retile_chain(i, chains, 0, seen);
|
||||
if(chains.size())
|
||||
clone_info[i] = chains;
|
||||
|
@@ -277,6 +277,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
ir::value* val = ret_;
|
||||
return set_ret(bld_->create_atomic_exch(ptr, val));
|
||||
}
|
||||
if(name == "f32_atomic_add"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
ir::value* ptr = ret_;
|
||||
VisitExpr(funcCall->Args()->at(1));
|
||||
ir::value* val = ret_;
|
||||
return set_ret(bld_->create_atomic_add(ptr, val));
|
||||
}
|
||||
if(name == "sqrtf"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
ir::value* ret = ret_;
|
||||
|
@@ -399,6 +399,7 @@ std::string function::preheader() {
|
||||
|
||||
extern int atomic_cas(int*, int, int);
|
||||
extern int atomic_xchg(int*, int);
|
||||
extern float f32_atomic_add(float*, float);
|
||||
extern int get_program_id(int);
|
||||
extern int get_num_programs(int);
|
||||
extern float sqrtf(float);
|
||||
|
Reference in New Issue
Block a user