From 226fde6ea19b823a568a4bd037901141b0c59ccf Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 17 Aug 2021 16:33:23 -0700 Subject: [PATCH] [CODEGEN] Now using atomic_rmw code path for atomic_xchg (#222) --- include/triton/codegen/selection/generator.h | 1 - include/triton/ir/builder.h | 1 - include/triton/ir/dispatch.h | 2 +- include/triton/ir/enums.h | 1 + include/triton/ir/instructions.h | 12 ---------- include/triton/ir/visitor.h | 2 -- lib/codegen/selection/generator.cc | 24 ++------------------ lib/ir/builder.cc | 4 ---- lib/ir/dispatch.cc | 10 ++++---- lib/ir/instructions.cc | 12 ---------- python/triton/language.py | 6 +++-- 11 files changed, 14 insertions(+), 61 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 2153e3e1a..ddde9198a 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -159,7 +159,6 @@ public: void visit_get_program_id_inst(ir::get_program_id_inst*); void visit_get_num_programs_inst(ir::get_num_programs_inst*); void visit_atomic_cas_inst(ir::atomic_cas_inst*); - void visit_atomic_exch_inst(ir::atomic_exch_inst*); void visit_atomic_rmw_inst(ir::atomic_rmw_inst*); void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index a5aaf0ec9..cdb76cd66 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -142,7 +142,6 @@ public: value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); value *create_atomic_cas(value *ptr, value *cmp, value *val); - value *create_atomic_exch(value *ptr, value *val); value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk); value *create_exp(value* arg); value *create_cos(value* arg); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index c034dc191..97b787faa 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -70,13 +70,13 @@ struct dispatch{ static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); - static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); // linear algebra static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index dc677cc51..48bdb7c66 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -38,6 +38,7 @@ enum class atomic_rmw_op_t: unsigned int{ UMax, UMin, FAdd, + Xchg, }; enum cast_op_t: unsigned int { diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 35759a5b0..e3a389018 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -634,18 +634,6 @@ public: static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); }; -class atomic_exch_inst: public atomic_inst { -private: - atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "atomic_exch"; } - _TRITON_DEFINE_CLONE(atomic_exch_inst) - _TRITON_DEFINE_ACCEPT(atomic_exch_inst) - -public: - static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); -}; - - class exp_inst: public builtin_inst { private: exp_inst(value *val, const std::string &name = "", instruction *next = nullptr); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 9600cccb1..a96211227 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -57,7 +57,6 @@ class get_program_id_inst; class get_num_programs_inst; class atomic_inst; class atomic_cas_inst; -class atomic_exch_inst; class atomic_rmw_inst; class dot_inst; class trans_inst; @@ -136,7 +135,6 @@ public: virtual void visit_get_program_id_inst(get_program_id_inst*) = 0; virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0; virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0; - virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0; virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0; virtual void visit_dot_inst(dot_inst*) = 0; virtual void visit_trans_inst(trans_inst*) = 0; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d00984167..8098462bb 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -922,30 +922,9 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { vals_[cas][{}] = load(atom_ptr); add_barrier(); } -/** - * \brief Code Generation for `atomic_exch` - */ -void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); - Value *rmw_ptr = vals_[xchg->get_operand(0)][{}]; - Value *rmw_val = vals_[xchg->get_operand(1)][{}]; - Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = icmp_eq(tid, i32(0)); - BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); - BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - tgt_->add_memfence(module, *builder_); - add_barrier(); - cond_br(pred, tid_0_bb, tid_0_done_bb); - builder_->SetInsertPoint(tid_0_bb); - atomic_rmw(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System); - br(tid_0_done_bb); - builder_->SetInsertPoint(tid_0_done_bb); - tgt_->add_memfence(module, *builder_); -} /** - * \brief Code Generation for `atomic_add` + * \brief Code Generation for `atomic_rmw` */ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { ir::value* ptr = atom->get_operand(0); @@ -1000,6 +979,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { case tt::UMin: name = "min", s_ty = "u"; break; case tt::UMax: name = "max", s_ty = "u"; break; case tt::FAdd: name = "add", s_ty = "f"; break; + case tt::Xchg: name = "exch", s_ty = "b"; break; } std::string s_vec = vec == 2 ? "x2" : ""; std::string mod = nbits == 32 ? "" : ".noftz"; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 6173a95cc..7f06ea370 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -325,10 +325,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ return insert(atomic_cas_inst::create(ptr, cmp, val)); } -value *builder::create_atomic_exch(value *ptr, value *val){ - return insert(atomic_exch_inst::create(ptr, val)); -} - value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ return insert(atomic_rmw_inst::create(op, ptr, val, msk)); } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index a55e318b5..21169c031 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -534,10 +534,6 @@ ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, return builder->create_atomic_cas(ptr, cmp, val); } -ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){ - return builder->create_atomic_exch(ptr, val); -} - void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){ if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); @@ -615,6 +611,12 @@ ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask); } +ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + ir::type* sca_ty = val->get_type()->get_scalar_ty(); + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask); +} + //===----------------------------------------------------------------------===// // Linear Algebra //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 1159b22ee..18da3d3ae 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -757,18 +757,6 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s return new atomic_cas_inst(ptr, cmp, val, name, next); } -// atomic exch - -atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next) - : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) { - set_operand(0, ptr); - set_operand(1, val); -} - -instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string &name, instruction *next) { - return new atomic_exch_inst(ptr, val, name, next); -} - // exp diff --git a/python/triton/language.py b/python/triton/language.py index 6a1a44d60..7255bf542 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -442,7 +442,7 @@ def atomic_cas(pointer, cmp, val, builder=None): @builtin -def atomic_xchg(pointer, val, builder=None): +def atomic_xchg(pointer, val, mask=None, builder=None): """ Swaps the *old* values stored at location :code:`pointer` with the new values given by :code:`val`. Returns the old values. @@ -450,8 +450,10 @@ def atomic_xchg(pointer, val, builder=None): :type pointer: Block of dtype=triton.PointerDType :param val: The new values to store :type val: Block of dtype=`pointer.dtype.element_ty` + :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected. + :type mask: Block of triton.int1, optional """ - return frontend.atomic_xchg(pointer, val, builder) + return frontend.atomic_xchg(pointer, val, mask, builder) @builtin