From f81012a8cf4953b4df638e1724a5eb5b571963c4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 19 May 2021 21:40:41 -0400 Subject: [PATCH] [CODEGEN] Fixed atomic_add issue (#112) * [CODEGEN] Fixed atomic_add issue * [CODEGEN] Fixed liveness analysis bug for instructions that are not DCE'd but have no users (e.g., atomic_cas) --- include/triton/ir/instructions.h | 32 +++++++++++++++++++----------- lib/codegen/analysis/layout.cc | 2 +- lib/codegen/analysis/liveness.cc | 2 ++ lib/codegen/selection/generator.cc | 20 +++++++++++++------ lib/ir/instructions.cc | 27 +++++++++++-------------- python/test/test_language.py | 28 ++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 34 deletions(-) diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index ca06080ff..ed4ad764e 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -452,16 +452,6 @@ public: _TRITON_DEFINE_ACCEPT(masked_load_async_inst) }; -class atomic_add_inst: public io_inst { -private: - atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); - std::string repr_impl() const { return "atomic_add"; } - _TRITON_DEFINE_CLONE(atomic_add_inst) - _TRITON_DEFINE_ACCEPT(atomic_add_inst) - -public: - static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); -}; // store @@ -612,7 +602,25 @@ private: unsigned axis_; }; -class atomic_cas_inst: public builtin_inst { + +class atomic_inst: public io_inst { +public: + using io_inst::io_inst; +}; + +class atomic_add_inst: public atomic_inst { +private: + atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "atomic_add"; } + _TRITON_DEFINE_CLONE(atomic_add_inst) + _TRITON_DEFINE_ACCEPT(atomic_add_inst) + +public: + static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); +}; + + +class atomic_cas_inst: public atomic_inst { private: atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); std::string repr_impl() const { return "atomic_cas"; } @@ -623,7 +631,7 @@ public: static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); }; -class atomic_exch_inst: public builtin_inst { +class atomic_exch_inst: public atomic_inst { private: atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); std::string repr_impl() const { return "atomic_exch"; } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 0d53f4c73..faab4a47f 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -436,7 +436,7 @@ void layouts::run(ir::module &mod) { layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); tmp_[recoalasce] = id; } - if(auto *atom = dynamic_cast(i)){ + if(auto *atom = dynamic_cast(i)){ id++; layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); tmp_[atom] = id; diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 224a93fc9..7beae21a1 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -45,6 +45,8 @@ void liveness::run(ir::module &mod) { for(ir::user *u: users) if(indices.find(u) != indices.end()) end = std::max(end, indices.at(u)); + if(end == 0) + end = start + 1; intervals_[layout] = segment{start, end}; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index bb9f518e9..630ae855f 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -902,13 +902,14 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { // asm string std::string suffix = vec == 2 ? "x2" : ""; std::string mod = nbits == 32 ? "" : ".noftz"; - std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;"; - std::string ty_id = nbits == 32 ? "f" : (vec == 1 ? "h" : "r"); + std::string ty_str = ty->isFloatingPointTy() ? "f" : "u"; + std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;"; + std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r"); std::string constraint = "b,=" + ty_id + ",l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); // call asm - call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + vals_[add][idx] = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); } } else{ @@ -920,8 +921,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); std::string mod = nbits == 32 ? "" : ".noftz"; - std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2], $3;"; - std::string ty_id = nbits == 32 ? "f" : "h"; + std::string ty_str = ty->isFloatingPointTy() ? "f" : "u"; + std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;"; + std::string ty_id = nbits == 32 ? "r" : "h"; InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true); BasicBlock *current = builder_->GetInsertBlock(); @@ -935,10 +937,16 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { add_barrier(); cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + Value *old = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + Value *atom_ptr; + atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), ""); + atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); + store(old, atom_ptr); br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); + add_barrier(); + vals_[add][{}] = load(atom_ptr); } } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 1162b1d15..a7c113680 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -484,19 +484,6 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, return new masked_load_async_inst(ptr, mask, false_value, name, next); } -// atomic add - -atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next) - : io_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) { - set_operand(0, ptr); - set_operand(1, val); - set_operand(2, msk); -} - -instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) { - return new atomic_add_inst(ptr, val, msk, name, next); -} - // store store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) @@ -744,10 +731,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st } +atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next) + : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) { + set_operand(0, ptr); + set_operand(1, val); + set_operand(2, msk); +} + +instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) { + return new atomic_add_inst(ptr, val, msk, name, next); +} // atomic cas atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) - : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) { + : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) { set_operand(0, ptr); set_operand(1, cmp); set_operand(2, val); @@ -760,7 +757,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s // atomic exch atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next) - : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) { + : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) { set_operand(0, ptr); set_operand(1, val); } diff --git a/python/test/test_language.py b/python/test/test_language.py index fbf5af3b1..4b447af78 100644 --- a/python/test/test_language.py +++ b/python/test/test_language.py @@ -189,6 +189,34 @@ def test_index1d(expr, device='cuda'): triton.testing.assert_allclose(z_ref, z_tri) +# --------------- +# test atomics +# --------------- +@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32']) +def test_atomic_add(dtype_x, device='cuda'): + dtype_x = cvt[dtype_x] + n_programs = 37 + + # triton kernel + @triton.jit + def kernel(X, Z, **meta): + pid = tl.program_id(0) + old = tl.atomic_add(X, pid) + tl.store(Z + pid, old) + + # triton result + x_tri = torch.zeros((1, ), dtype=dtype_x, device=device) + z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device) + kernel[(n_programs, )](x_tri, z_tri) + last_sum = torch.max(z_tri) + torch.argmax(z_tri) + last_sum = last_sum.to(dtype_x) + # torch result + range = torch.arange(n_programs, dtype=torch.int32, device=device) + x_ref = torch.sum(range).to(dtype_x) + triton.testing.assert_allclose(x_ref, x_tri[0]) + triton.testing.assert_allclose(x_ref, last_sum) + + # --------------- # test load # ---------------