[CODEGEN] Bugfix in Disassociate pass; Added fp32 atomic_add support

This commit is contained in:
Philippe Tillet
2020-05-13 23:21:21 -04:00
committed by Philippe Tillet
parent bb2d98ce4b
commit e7461a862b
4 changed files with 32 additions and 4 deletions

View File

@@ -618,8 +618,25 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
tgt_->add_memfence(module, *builder_);
}
void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
throw std::runtime_error("unsupported");
// Lowers an ir::atomic_add_inst to LLVM IR.
//
// The emitted code has this shape:
//   memfence; barrier;
//   if (local_id(0) == 0) { atomicrmw fadd ptr, val }
//   memfence;
// so the floating-point atomic add is only executed on the path where the
// local id queried below equals 0.
//
// NOTE(review): the value returned by CreateAtomicRMW is discarded and this
// function records nothing for `add` itself; confirm that no consumer of the
// instruction expects a result value.
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
// Block currently being filled, and the module that owns it.
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
// Already-lowered LLVM values for operand 0 (used as the RMW pointer) and
// operand 1 (used as the value to add).
Value *rmw_ptr = vmap_.at(add->get_operand(0));
Value *rmw_val = vmap_.at(add->get_operand(1));
// Local id for index 0 as reported by the target — presumably the thread
// index within its block along the first axis; confirm against
// target::get_local_id.
Value *tid = tgt_->get_local_id(module, *builder_, 0);
// Predicate that is true only when that id is 0.
Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
// Two new blocks in the current function: one holding the atomic, one that
// both paths rejoin.
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
// Fence and barrier are emitted in the original block, before the branch,
// so they sit on the path taken regardless of the predicate.
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
// Terminate the original block: predicate true -> tid_0, false -> tid_0_done.
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
// Floating-point atomic add with monotonic ordering at system sync scope.
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
// Fall through to the join block.
builder_->CreateBr(tid_0_done_bb);
// Leave the builder positioned in the join block so later lowering
// continues there.
builder_->SetInsertPoint(tid_0_done_bb);
// Trailing fence emitted in the join block, reached from both paths.
tgt_->add_memfence(module, *builder_);
}
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {

View File

@@ -34,10 +34,13 @@ void disassociate::run(ir::module &mod) {
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
ir::value* op = i->get_operand(0);
if(!dynamic_cast<ir::user*>(op))
return;
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
return;
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen;
if(!dynamic_cast<ir::user*>(i->get_operand(0)))
return;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;

View File

@@ -277,6 +277,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name == "f32_atomic_add"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_add(ptr, val));
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;

View File

@@ -399,6 +399,7 @@ std::string function::preheader() {
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern float f32_atomic_add(float*, float);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);