[IR] Added IR and Codegen support for atomic_rmw (#120)
This commit is contained in:
committed by
Philippe Tillet
parent
59b0ac672a
commit
0274429429
@@ -403,6 +403,7 @@ void layouts::run(ir::module &mod) {
|
||||
for(const auto& x: values_)
|
||||
create(x.first, x.second);
|
||||
|
||||
|
||||
// create temporaries
|
||||
size_t id = values_.size();
|
||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||
|
@@ -93,7 +93,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
prefetch_s.run(ir);
|
||||
barriers.run(ir);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
mod = driver::module::create(dev, std::move(llvm));
|
||||
|
@@ -813,28 +813,35 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||
Value *pred = icmp_eq(tid, i32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||
// BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||
// BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||
add_barrier();
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_bb);
|
||||
Value *atom_ptr;
|
||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
|
||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
|
||||
// cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||
// builder_->SetInsertPoint(tid_0_bb);
|
||||
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
|
||||
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
|
||||
Value *cas_val = vals_[cas->get_operand(2)][{}];
|
||||
Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
|
||||
old = extract_val(old, std::vector<unsigned>{0});
|
||||
Value *atom_ptr;
|
||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
|
||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
||||
store(old, atom_ptr);
|
||||
br(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
|
||||
FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
|
||||
add_barrier();
|
||||
Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
|
||||
add_barrier();
|
||||
|
||||
std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
|
||||
FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
|
||||
InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
|
||||
add_barrier();
|
||||
call(iasm2, {pred, atom_ptr, old});
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
add_barrier();
|
||||
vals_[cas][{}] = load(atom_ptr);
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `atomic_exch`
|
||||
*/
|
||||
@@ -860,93 +867,86 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
/**
|
||||
* \brief Code Generation for `atomic_add`
|
||||
*/
|
||||
//TODO: clean-up
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
ir::value* ptr = atom->get_operand(0);
|
||||
ir::value* val = atom->get_operand(1);
|
||||
ir::value* msk = atom->get_operand(2);
|
||||
|
||||
if(add->get_type()->is_block_ty()){
|
||||
ir::value* ptr = add->get_operand(0);
|
||||
ir::value* val = add->get_operand(1);
|
||||
ir::value* msk = add->get_operand(2);
|
||||
|
||||
// vector size
|
||||
int vec = 1;
|
||||
// vector size
|
||||
int vec = 1;
|
||||
if(atom->get_type()->is_block_ty()){
|
||||
int ld = ords_.at(ptr)[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
||||
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
|
||||
|
||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||
auto idx = idxs_[val][i];
|
||||
Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
|
||||
for(int ii = 0; ii < vec; ii++)
|
||||
rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
|
||||
Value *rmw_ptr = vals_[ptr][idx];
|
||||
Value *rmw_msk = vals_[msk][idx];
|
||||
if(vec == 1)
|
||||
rmw_val = extract_elt(rmw_val, i32(0));
|
||||
Type* ty = rmw_val->getType();
|
||||
size_t nbits = ty->getScalarSizeInBits();
|
||||
// extract pointer offset
|
||||
std::string offset = "";
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
|
||||
if(gep->getNumIndices() == 1)
|
||||
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
|
||||
rmw_ptr = gep->getPointerOperand();
|
||||
}
|
||||
rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||
// asm string
|
||||
std::string suffix = vec == 2 ? "x2" : "";
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
|
||||
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
|
||||
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
}
|
||||
}
|
||||
else{
|
||||
Value *rmw_ptr = vals_[add->get_operand(0)][{}];
|
||||
Value *rmw_val = vals_[add->get_operand(1)][{}];
|
||||
Value *rmw_msk = vals_[add->get_operand(2)][{}];
|
||||
|
||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||
auto idx = idxs_[val][i];
|
||||
Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
|
||||
for(int ii = 0; ii < vec; ii++)
|
||||
rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
|
||||
Value *rmw_ptr = vals_[ptr][idx];
|
||||
Value *rmw_msk = vals_[msk][idx];
|
||||
if(vec == 1)
|
||||
rmw_val = extract_elt(rmw_val, i32(0));
|
||||
Type* ty = rmw_val->getType();
|
||||
size_t nbits = ty->getScalarSizeInBits();
|
||||
// extract pointer offset
|
||||
std::string offset = "";
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
|
||||
if(gep->getNumIndices() == 1)
|
||||
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
|
||||
rmw_ptr = gep->getPointerOperand();
|
||||
}
|
||||
rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||
// asm string
|
||||
std::string s_nbits = std::to_string(nbits);
|
||||
std::string name;
|
||||
std::string s_ty;
|
||||
using tt = ir::atomic_rmw_op_t;
|
||||
switch(atom->get_op()){
|
||||
case tt::Or: name = "or"; s_ty = "b"; break;
|
||||
case tt::And: name = "and"; s_ty = "b"; break;
|
||||
case tt::Xor: name = "xor", s_ty = "b"; break;
|
||||
case tt::Add: name = "add" , s_ty = "s"; break;
|
||||
case tt::Min: name = "min", s_ty = "s"; break;
|
||||
case tt::Max: name = "max", s_ty = "s"; break;
|
||||
case tt::UMin: name = "min", s_ty = "u"; break;
|
||||
case tt::UMax: name = "max", s_ty = "u"; break;
|
||||
case tt::FAdd: name = "add", s_ty = "f"; break;
|
||||
}
|
||||
std::string s_vec = vec == 2 ? "x2" : "";
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
|
||||
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;";
|
||||
std::string ty_id = nbits == 32 ? "r" : "h";
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true);
|
||||
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
|
||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||
Value *pred = icmp_eq(tid, i32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
add_barrier();
|
||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_bb);
|
||||
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
Value *atom_ptr;
|
||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), "");
|
||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
||||
store(old, atom_ptr);
|
||||
br(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
add_barrier();
|
||||
vals_[add][{}] = load(atom_ptr);
|
||||
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits*vec == 32 ? "r" : "h";
|
||||
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
if(atom->get_type()->is_block_ty())
|
||||
vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
else{
|
||||
Module *mod = builder_->GetInsertBlock()->getModule();
|
||||
tgt_->add_memfence(mod, *builder_);
|
||||
add_barrier();
|
||||
Value *tid = tgt_->get_local_id(mod, *builder_, 0);
|
||||
rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0)));
|
||||
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
Value *atom_ptr;
|
||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "");
|
||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
||||
store(old, atom_ptr);
|
||||
add_barrier();
|
||||
vals_[atom][idx] = load(atom_ptr);
|
||||
add_barrier();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -25,8 +25,8 @@ void dce::run(ir::module &mod) {
|
||||
case ir::INST_COND_BRANCH:
|
||||
case ir::INST_UNMASKED_STORE:
|
||||
case ir::INST_MASKED_STORE:
|
||||
case ir::INST_ATOMIC_ADD:
|
||||
case ir::INST_ATOMIC_CAS:
|
||||
case ir::INST_ATOMIC_RMW:
|
||||
case ir::INST_ATOMIC_EXCH:
|
||||
case ir::INST_BARRIER: {
|
||||
work_list.push_back(i);
|
||||
|
@@ -323,8 +323,8 @@ value *builder::create_atomic_exch(value *ptr, value *val){
|
||||
return insert(atomic_exch_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_add(value *ptr, value *val, value *msk){
|
||||
return insert(atomic_add_inst::create(ptr, val, msk));
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
value *builder::create_exp(value *arg){
|
||||
|
@@ -538,7 +538,7 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
|
||||
return builder->create_atomic_exch(ptr, val);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
|
||||
if(ptr->get_type()->is_block_ty()){
|
||||
@@ -550,14 +550,69 @@ ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask,
|
||||
}
|
||||
}
|
||||
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
|
||||
if(!mask){
|
||||
mask = builder->get_int1(true);
|
||||
if(ptr->get_type()->is_block_ty())
|
||||
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
|
||||
}
|
||||
}
|
||||
|
||||
return builder->create_atomic_add(ptr, val, mask);
|
||||
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
|
||||
// for float
|
||||
// return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umin(i_ptr, i_val) if val < 0
|
||||
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
|
||||
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
|
||||
return where(pos, pos_ret, neg_ret, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
|
||||
// for float
|
||||
// return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umax(i_ptr, i_val) if val < 0
|
||||
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
|
||||
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
|
||||
return where(pos, pos_ret, neg_ret, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
|
||||
return builder->create_atomic_rmw(op, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -730,17 +730,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
|
||||
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// atomic_rmw
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
|
||||
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_add_inst(ptr, val, msk, name, next);
|
||||
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_rmw_inst(op, ptr, val, msk, name, next);
|
||||
}
|
||||
|
||||
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
|
Reference in New Issue
Block a user