[IR] Added IR and Codegen support for atomic_rmw (#120)
This commit is contained in:
committed by
Philippe Tillet
parent
59b0ac672a
commit
0274429429
@@ -323,8 +323,8 @@ value *builder::create_atomic_exch(value *ptr, value *val){
|
||||
return insert(atomic_exch_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_add(value *ptr, value *val, value *msk){
|
||||
return insert(atomic_add_inst::create(ptr, val, msk));
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
value *builder::create_exp(value *arg){
|
||||
|
@@ -538,7 +538,7 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
|
||||
return builder->create_atomic_exch(ptr, val);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
|
||||
if(ptr->get_type()->is_block_ty()){
|
||||
@@ -550,14 +550,69 @@ ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask,
|
||||
}
|
||||
}
|
||||
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||
|
||||
if(!mask){
|
||||
mask = builder->get_int1(true);
|
||||
if(ptr->get_type()->is_block_ty())
|
||||
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
|
||||
}
|
||||
}
|
||||
|
||||
return builder->create_atomic_add(ptr, val, mask);
|
||||
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
|
||||
// for float
|
||||
// return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umin(i_ptr, i_val) if val < 0
|
||||
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
|
||||
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
|
||||
return where(pos, pos_ret, neg_ret, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
|
||||
// for float
|
||||
// return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umax(i_ptr, i_val) if val < 0
|
||||
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
|
||||
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
|
||||
return where(pos, pos_ret, neg_ret, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
|
||||
return builder->create_atomic_rmw(op, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -730,17 +730,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
|
||||
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// atomic_rmw
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
|
||||
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_add_inst(ptr, val, msk, name, next);
|
||||
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_rmw_inst(op, ptr, val, msk, name, next);
|
||||
}
|
||||
|
||||
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
|
Reference in New Issue
Block a user