diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 61e513d47..c4a5abed6 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -154,7 +154,7 @@ public:
   void visit_get_num_programs_inst(ir::get_num_programs_inst*);
   void visit_atomic_cas_inst(ir::atomic_cas_inst*);
   void visit_atomic_exch_inst(ir::atomic_exch_inst*);
-  void visit_atomic_add_inst(ir::atomic_add_inst*);
+  void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
   void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
   void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
   void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 5029aa673..bfbbf6c66 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -137,7 +137,7 @@ public:
   value *create_get_num_programs(unsigned axis);
   value *create_atomic_cas(value *ptr, value *cmp, value *val);
   value *create_atomic_exch(value *ptr, value *val);
-  value *create_atomic_add(value *ptr, value *val, value *msk);
+  value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
   value *create_exp(value* arg);
   value *create_log(value* arg);
   value *create_dot(value *A, value *B, value *C);
diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
index f9ada21ef..b7fd85207 100644
--- a/include/triton/ir/dispatch.h
+++ b/include/triton/ir/dispatch.h
@@ -71,6 +71,11 @@ struct dispatch{
   static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
   static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
   static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
 
   // linear algebra
   static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h
index 720a2e307..499767026 100644
--- a/include/triton/ir/enums.h
+++ b/include/triton/ir/enums.h
@@ -28,6 +28,18 @@ enum binary_op_t: unsigned int{
   Xor
 };
 
+enum class atomic_rmw_op_t: unsigned int{
+  And,
+  Or,
+  Xor,
+  Add,
+  Max,
+  Min,
+  UMax,
+  UMin,
+  FAdd,
+};
+
 enum cast_op_t: unsigned int {
   Trunc,
   ZExt,
@@ -127,7 +139,7 @@ enum value_id_t: unsigned {
   // atomics
   INST_ATOMIC_CAS,
   INST_ATOMIC_EXCH,
-  INST_ATOMIC_ADD,
+  INST_ATOMIC_RMW,
   // math
   INST_EXP,
   INST_LOG,
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index ed4ad764e..2e3b1e9ed 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -608,17 +608,20 @@ public:
   using io_inst::io_inst;
 };
 
-class atomic_add_inst: public atomic_inst {
+class atomic_rmw_inst: public atomic_inst {
 private:
-  atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
-  std::string repr_impl() const { return "atomic_add"; }
-  _TRITON_DEFINE_CLONE(atomic_add_inst)
-  _TRITON_DEFINE_ACCEPT(atomic_add_inst)
+  atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
+  std::string repr_impl() const { return "atomic_rmw"; }
+  _TRITON_DEFINE_CLONE(atomic_rmw_inst)
+  _TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
 
 public:
-  static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
-};
+  static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
+  atomic_rmw_op_t get_op() { return op_; }
+private:
+  atomic_rmw_op_t op_;
+};
 
 class atomic_cas_inst: public atomic_inst {
 private:
diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h
index f244fbb06..a42800714 100644
--- a/include/triton/ir/visitor.h
+++ b/include/triton/ir/visitor.h
@@ -53,9 +53,10 @@ class log_inst;
 class get_program_id_inst;
 class get_num_programs_inst;
+class atomic_inst;
 class atomic_cas_inst;
 class atomic_exch_inst;
-class atomic_add_inst;
+class atomic_rmw_inst;
 class dot_inst;
 class trans_inst;
 class sqrt_inst;
@@ -132,7 +133,7 @@ public:
   virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
   virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
   virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
-  virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
+  virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
   virtual void visit_dot_inst(dot_inst*) = 0;
   virtual void visit_trans_inst(trans_inst*) = 0;
   virtual void visit_sqrt_inst(sqrt_inst*) = 0;
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index faab4a47f..dc280aba5 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -403,6 +403,7 @@ void layouts::run(ir::module &mod) {
   for(const auto& x: values_)
     create(x.first, x.second);
 
+  // create temporaries
   size_t id = values_.size();
   ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 2100a7770..66d3c238c 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -93,7 +93,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   liveness.run(ir);
   allocation.run(ir);
   prefetch_s.run(ir);
-  barriers.run(ir);
+  barriers.run(ir); // ir::print(ir, std::cout);
   isel.visit(ir, *llvm);
   mod = driver::module::create(dev, std::move(llvm));
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 6507080f9..349056015 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -813,28 +813,35 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   Module *module = current->getModule();
   Value *tid = tgt_->get_local_id(module, *builder_, 0);
   Value *pred = icmp_eq(tid, i32(0));
-  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
-  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+//  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+//  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
   add_barrier();
   tgt_->add_memfence(module, *builder_);
-  cond_br(pred, tid_0_bb, tid_0_done_bb);
-  builder_->SetInsertPoint(tid_0_bb);
+  Value *atom_ptr;
+  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
+  atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
+//  cond_br(pred, tid_0_bb, tid_0_done_bb);
+//  builder_->SetInsertPoint(tid_0_bb);
   Value *cas_ptr = vals_[cas->get_operand(0)][{}];
   Value *cas_cmp = vals_[cas->get_operand(1)][{}];
   Value *cas_val = vals_[cas->get_operand(2)][{}];
-  Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
-  old = extract_val(old, std::vector<unsigned>{0});
-  Value *atom_ptr;
-  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
-  atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
-  store(old, atom_ptr);
-  br(tid_0_done_bb);
-  builder_->SetInsertPoint(tid_0_done_bb);
+  std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
+  FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
+  InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
+  add_barrier();
+  Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
+  add_barrier();
+
+  std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
+  FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
+  InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
+  add_barrier();
+  call(iasm2, {pred, atom_ptr, old});
   tgt_->add_memfence(module, *builder_);
   add_barrier();
   vals_[cas][{}] = load(atom_ptr);
+  add_barrier();
 }
-
 /**
  * \brief Code Generation for `atomic_exch`
  */
@@ -860,93 +867,86 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 /**
  * \brief Code Generation for `atomic_add`
  */
-//TODO: clean-up
-void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
+  ir::value* ptr = atom->get_operand(0);
+  ir::value* val = atom->get_operand(1);
+  ir::value* msk = atom->get_operand(2);
 
-  if(add->get_type()->is_block_ty()){
-    ir::value* ptr = add->get_operand(0);
-    ir::value* val = add->get_operand(1);
-    ir::value* msk = add->get_operand(2);
-
-    // vector size
-    int vec = 1;
+  // vector size
+  int vec = 1;
+  if(atom->get_type()->is_block_ty()){
     int ld = ords_.at(ptr)[0];
     unsigned alignment = alignment_->get(ptr, ld);
     vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
     vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
-
-    for(int i = 0; i < idxs_.at(val).size(); i += vec){
-      auto idx = idxs_[val][i];
-      Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
-      for(int ii = 0; ii < vec; ii++)
-        rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
-      Value *rmw_ptr = vals_[ptr][idx];
-      Value *rmw_msk = vals_[msk][idx];
-      if(vec == 1)
-        rmw_val = extract_elt(rmw_val, i32(0));
-      Type* ty = rmw_val->getType();
-      size_t nbits = ty->getScalarSizeInBits();
-      // extract pointer offset
-      std::string offset = "";
-      if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
-      if(gep->getNumIndices() == 1)
-      if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
-        offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
-        rmw_ptr = gep->getPointerOperand();
-      }
-      rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
-      // asm argument type
-      std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
-      // asm function type
-      FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
-      // asm string
-      std::string suffix = vec == 2 ? "x2" : "";
-      std::string mod = nbits == 32 ? "" : ".noftz";
-      std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
-      std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;";
-      std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
-      std::string constraint = "=" + ty_id + ",b,l," + ty_id;
-      // create inline asm
-      InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
-      // call asm
-      vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
-    }
   }
-  else{
-    Value *rmw_ptr = vals_[add->get_operand(0)][{}];
-    Value *rmw_val = vals_[add->get_operand(1)][{}];
-    Value *rmw_msk = vals_[add->get_operand(2)][{}];
+
+  for(int i = 0; i < idxs_.at(val).size(); i += vec){
+    auto idx = idxs_[val][i];
+    Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
+    for(int ii = 0; ii < vec; ii++)
+      rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
+    Value *rmw_ptr = vals_[ptr][idx];
+    Value *rmw_msk = vals_[msk][idx];
+    if(vec == 1)
+      rmw_val = extract_elt(rmw_val, i32(0));
     Type* ty = rmw_val->getType();
     size_t nbits = ty->getScalarSizeInBits();
+    // extract pointer offset
+    std::string offset = "";
+    if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
+    if(gep->getNumIndices() == 1)
+    if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
+      offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
+      rmw_ptr = gep->getPointerOperand();
+    }
+    rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
+    // asm argument type
     std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
+    // asm function type
     FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
+    // asm string
+    std::string s_nbits = std::to_string(nbits);
+    std::string name;
+    std::string s_ty;
+    using tt = ir::atomic_rmw_op_t;
+    switch(atom->get_op()){
+      case tt::Or: name = "or"; s_ty = "b"; break;
+      case tt::And: name = "and"; s_ty = "b"; break;
+      case tt::Xor: name = "xor"; s_ty = "b"; break;
+      case tt::Add: name = "add"; s_ty = "s"; break;
+      case tt::Min: name = "min"; s_ty = "s"; break;
+      case tt::Max: name = "max"; s_ty = "s"; break;
+      case tt::UMin: name = "min"; s_ty = "u"; break;
+      case tt::UMax: name = "max"; s_ty = "u"; break;
+      case tt::FAdd: name = "add"; s_ty = "f"; break;
+    }
+    std::string s_vec = vec == 2 ? "x2" : "";
     std::string mod = nbits == 32 ? "" : ".noftz";
"" : ".noftz"; - std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u"; - std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;"; - std::string ty_id = nbits == 32 ? "r" : "h"; - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true); - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); - - Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = icmp_eq(tid, i32(0)); - BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); - BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - tgt_->add_memfence(module, *builder_); - add_barrier(); - cond_br(pred, tid_0_bb, tid_0_done_bb); - builder_->SetInsertPoint(tid_0_bb); - Value *old = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); - Value *atom_ptr; - atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), ""); - atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); - store(old, atom_ptr); - br(tid_0_done_bb); - builder_->SetInsertPoint(tid_0_done_bb); - tgt_->add_memfence(module, *builder_); - add_barrier(); - vals_[add][{}] = load(atom_ptr); + std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;"; + std::string ty_id = nbits*vec == 32 ? "r" : "h"; + std::string constraint = "=" + ty_id + ",b,l," + ty_id; + // create inline asm + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); + // call asm + if(atom->get_type()->is_block_ty()) + vals_[atom][idx] = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + else{ + Module *mod = builder_->GetInsertBlock()->getModule(); + tgt_->add_memfence(mod, *builder_); + add_barrier(); + Value *tid = tgt_->get_local_id(mod, *builder_, 0); + rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0))); + Value *old = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + Value *atom_ptr; + atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), ""); + atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); + store(old, atom_ptr); + add_barrier(); + vals_[atom][idx] = load(atom_ptr); + add_barrier(); + } } } diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc index 907b1621f..c555290f8 100644 --- a/lib/codegen/transform/dce.cc +++ b/lib/codegen/transform/dce.cc @@ -25,8 +25,8 @@ void dce::run(ir::module &mod) { case ir::INST_COND_BRANCH: case ir::INST_UNMASKED_STORE: case ir::INST_MASKED_STORE: - case ir::INST_ATOMIC_ADD: case ir::INST_ATOMIC_CAS: + case ir::INST_ATOMIC_RMW: case ir::INST_ATOMIC_EXCH: case ir::INST_BARRIER: { work_list.push_back(i); diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index b6baa9a2e..ad0faeb8d 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -323,8 +323,8 @@ value *builder::create_atomic_exch(value *ptr, value *val){ return insert(atomic_exch_inst::create(ptr, val)); } -value *builder::create_atomic_add(value *ptr, value *val, value *msk){ - return insert(atomic_add_inst::create(ptr, val, msk)); +value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ + return insert(atomic_rmw_inst::create(op, ptr, val, msk)); } value *builder::create_exp(value *arg){ diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index c3cd71c5d..ee3be6bcc 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -538,7 +538,7 @@ ir::value 
   return builder->create_atomic_exch(ptr, val);
 }
 
-ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
   if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
     throw semantic_error("Pointer argument of atomic operation is " + ptr->get_type()->repr());
   if(ptr->get_type()->is_block_ty()){
@@ -550,14 +550,69 @@ ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask,
     }
   }
   val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
-
   if(!mask){
     mask = builder->get_int1(true);
     if(ptr->get_type()->is_block_ty())
       mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
   }
+}
 
-  return builder->create_atomic_add(ptr, val, mask);
+ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_max for integers
+  if(sca_ty->is_integer_ty())
+    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
+  // for float
+  // return atomic_smax(i_ptr, i_val) if val >= 0
+  // return atomic_umin(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
+
+ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_min for integers
+  if(sca_ty->is_integer_ty())
+    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
+  // for float
+  // return atomic_smin(i_ptr, i_val) if val >= 0
+  // return atomic_umax(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
+
+ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
+  return builder->create_atomic_rmw(op, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index a7c113680..ebfb26c12 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -730,17 +730,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
   return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
 }
 
+// atomic_rmw
 
-atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
-  : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
+atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
+  : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
   set_operand(0, ptr);
   set_operand(1, val);
   set_operand(2, msk);
 }
 
-instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
-  return new atomic_add_inst(ptr, val, msk, name, next);
+instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
+  return new atomic_rmw_inst(op, ptr, val, msk, name, next);
 }
+
+
 // atomic cas
 
 atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 43a8e5642..2c4e65f1a 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -137,6 +137,11 @@ void init_triton_frontend(py::module &&m) {
   m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
   m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
   m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
+  m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
+  m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
+  m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
+  m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
+  m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
   // linear algebra
   m.def("dot", &ir::dispatch::dot, ret::reference);
   // indexing
diff --git a/python/test/test_language.py b/python/test/test_language.py
index 35cbe7b0e..dd5215b24 100644
--- a/python/test/test_language.py
+++ b/python/test/test_language.py
@@ -4,6 +4,7 @@ import triton.language as tl
 import copy
 import pytest
 import ast
+import itertools
 
 torch.manual_seed(0)
 
@@ -238,8 +239,13 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
-def test_atomic_add(dtype_x, device='cuda'):
+@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
+    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
+    ('max', 'int32', mode), ('max', 'float32', mode),\
+    ('min', 'int32', mode), ('min', 'float32', mode),\
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     dtype_x = cvt[dtype_x]
     n_programs = 37
 
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
     @triton.jit
     def kernel(X, Z, **meta):
         pid = tl.program_id(0)
-        old = tl.atomic_add(X, pid)
-        tl.store(Z + pid, old)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
+    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
+    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
 
     # triton result
-    x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
-    z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
+    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    if mode == 'all_neg':
+        x_tri = -torch.abs(x_tri)
+    if mode == 'all_pos':
+        x_tri = torch.abs(x_tri)
+    if mode == 'min_neg':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+    if mode == 'max_pos':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+
+    z_tri = torch.empty([], dtype=dtype_x, device=device)
+    z_tri.fill_(neutral)
     kernel[(n_programs, )](x_tri, z_tri)
-    last_sum = torch.max(z_tri) + torch.argmax(z_tri)
-    last_sum = last_sum.to(dtype_x)
     # torch result
-    range = torch.arange(n_programs, dtype=torch.int32, device=device)
-    x_ref = torch.sum(range).to(dtype_x)
-    triton.testing.assert_allclose(x_ref, x_tri[0])
-    triton.testing.assert_allclose(x_ref, last_sum)
+    z_ref = torch_op(x_tri).to(dtype_x)
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == z_tri.item()
+    else:
+        triton.testing.assert_allclose(z_ref, z_tri)
 
 
 # ---------------
diff --git a/python/triton/language.py b/python/triton/language.py
index e5791cebf..6029058bd 100644
--- a/python/triton/language.py
+++ b/python/triton/language.py
@@ -89,6 +89,8 @@ float16 = dtype(ir.type.get_fp16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)
 
+pi32_t = pointer_dtype(int32)
+
 
 class block:
     @staticmethod
@@ -464,6 +466,31 @@ def atomic_add(pointer, val, mask=None, builder=None):
     return frontend.atomic_add(pointer, val, mask, builder)
 
 
+@builtin
+def atomic_max(pointer, val, mask=None, builder=None):
+    return frontend.atomic_max(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_min(pointer, val, mask=None, builder=None):
+    return frontend.atomic_min(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_and(pointer, val, mask=None, builder=None):
+    return frontend.atomic_and(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_or(pointer, val, mask=None, builder=None):
+    return frontend.atomic_or(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_xor(pointer, val, mask=None, builder=None):
+    return frontend.atomic_xor(pointer, val, mask, builder)
+
+
 # -----------------------
 # Conditioning
 # -----------------------
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 2c9aa092f..e67d792c4 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -76,7 +76,7 @@ def random(shape, dtype, device):
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
         return torch.randint(1, 32, shape, dtype=dtype, device=device)
     if dtype in [torch.float16, torch.float32, torch.float64]:
-        return torch.randn(shape, dtype=dtype, device=device)
+        return torch.normal(0, 10, shape, dtype=dtype, device=device)
     raise RuntimeError(f'Unknown dtype {dtype}')
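
Note on the generated PTX: the inline-asm template assembled in `generator::visit_atomic_rmw_inst` is a plain string concatenation, so its output can be previewed without running the compiler. The sketch below re-implements that concatenation in Python purely for illustration; the helper name `ptx_atom` is made up here and the authoritative logic is the C++ `switch` in the diff above.

```python
# Hypothetical mirror of the C++ string assembly; for illustration only.
def ptx_atom(name, s_ty, nbits, vec=1, offset=""):
    mod = "" if nbits == 32 else ".noftz"
    s_vec = "x2" if vec == 2 else ""
    return ("@$1 atom.global.gpu." + name + mod + "." + s_ty + str(nbits)
            + s_vec + " $0, [$2" + offset + "], $3;")

assert ptx_atom("add", "f", 32) == "@$1 atom.global.gpu.add.f32 $0, [$2], $3;"
assert ptx_atom("add", "f", 16, vec=2) == "@$1 atom.global.gpu.add.noftz.f16x2 $0, [$2], $3;"
assert ptx_atom("max", "s", 32) == "@$1 atom.global.gpu.max.s32 $0, [$2], $3;"
```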
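The float paths of `dispatch::atomic_max`/`atomic_min` rely on the standard IEEE-754 ordering trick stated in their comments: for non-negative floats the bit pattern is monotonically increasing when read as a signed integer, so a signed integer max is a float max; for negative floats the ordering reverses when the bits are read as an unsigned integer, so a float max becomes an unsigned integer min (and symmetrically for min). A small self-contained check of that property, in plain Python:

```python
import struct

def i32(x):
    # float32 bits reinterpreted as a signed int32
    return struct.unpack('<i', struct.pack('<f', x))[0]

def u32(x):
    # float32 bits reinterpreted as an unsigned uint32
    return struct.unpack('<I', struct.pack('<f', x))[0]

# Non-negative floats: signed-integer view preserves order -> float max == signed max.
assert (i32(0.25) < i32(1.5)) == (0.25 < 1.5)
# Negative floats: unsigned view reverses order -> float max == unsigned *min* over the bits.
assert (-0.25 > -1.5) and (u32(-0.25) < u32(-1.5))
```

This is why the dispatch code issues two masked atomics (one on the `val >= 0` lanes, one on the `val < 0` lanes) and selects between the results with `where`.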
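For reference, a minimal sketch of the resulting user-facing API, mirroring the pattern of the new `test_atomic_rmw` test (the kernel name and tensor shapes below are illustrative, not part of the patch): each program folds one element into a single global accumulator that was pre-filled with the operation's neutral element.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def global_max_kernel(X, Z, **meta):
    # one element per program, atomically folded into a scalar maximum
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    tl.atomic_max(Z, x)

x = torch.randn(37, device='cuda')
z = torch.full((1, ), float('-inf'), device='cuda')  # neutral element of max
global_max_kernel[(x.numel(), )](x, z)
assert z.item() == x.max().item()
```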