[CODEGEN] Now using atomic_rmw code path for atomic_xchg (#222)
This commit is contained in:
@@ -159,7 +159,6 @@ public:
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
|
@@ -142,7 +142,6 @@ public:
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_exch(value *ptr, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
|
@@ -70,13 +70,13 @@ struct dispatch{
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
|
@@ -38,6 +38,7 @@ enum class atomic_rmw_op_t: unsigned int{
|
||||
UMax,
|
||||
UMin,
|
||||
FAdd,
|
||||
Xchg,
|
||||
};
|
||||
|
||||
enum cast_op_t: unsigned int {
|
||||
|
@@ -634,18 +634,6 @@ public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class atomic_exch_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_exch"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_exch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_exch_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
|
@@ -57,7 +57,6 @@ class get_program_id_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_exch_inst;
|
||||
class atomic_rmw_inst;
|
||||
class dot_inst;
|
||||
class trans_inst;
|
||||
@@ -136,7 +135,6 @@ public:
|
||||
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
|
||||
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||
|
@@ -922,30 +922,9 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
||||
vals_[cas][{}] = load(atom_ptr);
|
||||
add_barrier();
|
||||
}
|
||||
/**
|
||||
* \brief Code Generation for `atomic_exch`
|
||||
*/
|
||||
void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *rmw_ptr = vals_[xchg->get_operand(0)][{}];
|
||||
Value *rmw_val = vals_[xchg->get_operand(1)][{}];
|
||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||
Value *pred = icmp_eq(tid, i32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
add_barrier();
|
||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_bb);
|
||||
atomic_rmw(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
|
||||
br(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `atomic_add`
|
||||
* \brief Code Generation for `atomic_rmw`
|
||||
*/
|
||||
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
ir::value* ptr = atom->get_operand(0);
|
||||
@@ -1000,6 +979,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
case tt::UMin: name = "min", s_ty = "u"; break;
|
||||
case tt::UMax: name = "max", s_ty = "u"; break;
|
||||
case tt::FAdd: name = "add", s_ty = "f"; break;
|
||||
case tt::Xchg: name = "exch", s_ty = "b"; break;
|
||||
}
|
||||
std::string s_vec = vec == 2 ? "x2" : "";
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
|
@@ -325,10 +325,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
|
||||
return insert(atomic_cas_inst::create(ptr, cmp, val));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_exch(value *ptr, value *val){
|
||||
return insert(atomic_exch_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
@@ -534,10 +534,6 @@ ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val,
|
||||
return builder->create_atomic_cas(ptr, cmp, val);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){
|
||||
return builder->create_atomic_exch(ptr, val);
|
||||
}
|
||||
|
||||
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
|
||||
@@ -615,6 +611,12 @@ ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask,
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
|
||||
}
|
||||
|
||||
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linear Algebra
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -757,18 +757,6 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
||||
}
|
||||
|
||||
// atomic exch
|
||||
|
||||
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
|
||||
return new atomic_exch_inst(ptr, val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// exp
|
||||
|
||||
|
@@ -442,7 +442,7 @@ def atomic_cas(pointer, cmp, val, builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def atomic_xchg(pointer, val, builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, builder=None):
|
||||
"""
|
||||
Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values.
|
||||
|
||||
@@ -450,8 +450,10 @@ def atomic_xchg(pointer, val, builder=None):
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param val: The new values to store
|
||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
|
||||
:type mask: Block of triton.int1, optional
|
||||
"""
|
||||
return frontend.atomic_xchg(pointer, val, builder)
|
||||
return frontend.atomic_xchg(pointer, val, mask, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
|
Reference in New Issue
Block a user