[CODEGEN] Fixed atomic_add issue (#112)
* [CODEGEN] Fixed atomic_add issue * [CODEGEN] Fixed liveness analysis bug for instructions that are not DCE'd but have no users (e.g., atomic_cas)
This commit is contained in:
committed by
Philippe Tillet
parent
325ee38581
commit
f81012a8cf
@@ -452,16 +452,6 @@ public:
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
|
||||
};
|
||||
|
||||
class atomic_add_inst: public io_inst {
|
||||
private:
|
||||
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_add"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
// store
|
||||
@@ -612,7 +602,25 @@ private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public builtin_inst {
|
||||
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
};
|
||||
|
||||
class atomic_add_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_add"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
class atomic_cas_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
@@ -623,7 +631,7 @@ public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class atomic_exch_inst: public builtin_inst {
|
||||
class atomic_exch_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_exch"; }
|
||||
|
@@ -436,7 +436,7 @@ void layouts::run(ir::module &mod) {
|
||||
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
|
||||
tmp_[recoalasce] = id;
|
||||
}
|
||||
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
|
||||
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
|
||||
id++;
|
||||
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
|
||||
tmp_[atom] = id;
|
||||
|
@@ -45,6 +45,8 @@ void liveness::run(ir::module &mod) {
|
||||
for(ir::user *u: users)
|
||||
if(indices.find(u) != indices.end())
|
||||
end = std::max(end, indices.at(u));
|
||||
if(end == 0)
|
||||
end = start + 1;
|
||||
intervals_[layout] = segment{start, end};
|
||||
}
|
||||
|
||||
|
@@ -902,13 +902,14 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
// asm string
|
||||
std::string suffix = vec == 2 ? "x2" : "";
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits == 32 ? "f" : (vec == 1 ? "h" : "r");
|
||||
std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
|
||||
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
|
||||
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
}
|
||||
}
|
||||
else{
|
||||
@@ -920,8 +921,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2], $3;";
|
||||
std::string ty_id = nbits == 32 ? "f" : "h";
|
||||
std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
|
||||
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;";
|
||||
std::string ty_id = nbits == 32 ? "r" : "h";
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true);
|
||||
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
@@ -935,10 +937,16 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
add_barrier();
|
||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_bb);
|
||||
call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||
Value *atom_ptr;
|
||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), "");
|
||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
||||
store(old, atom_ptr);
|
||||
br(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, *builder_);
|
||||
add_barrier();
|
||||
vals_[add][{}] = load(atom_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -484,19 +484,6 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
|
||||
return new masked_load_async_inst(ptr, mask, false_value, name, next);
|
||||
}
|
||||
|
||||
// atomic add
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: io_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_add_inst(ptr, val, msk, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
@@ -744,10 +731,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
|
||||
}
|
||||
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, msk);
|
||||
}
|
||||
|
||||
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||
return new atomic_add_inst(ptr, val, msk, name, next);
|
||||
}
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, cmp);
|
||||
set_operand(2, val);
|
||||
@@ -760,7 +757,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
// atomic exch
|
||||
|
||||
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
|
||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
@@ -189,6 +189,34 @@ def test_index1d(expr, device='cuda'):
|
||||
triton.testing.assert_allclose(z_ref, z_tri)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
|
||||
def test_atomic_add(dtype_x, device='cuda'):
|
||||
dtype_x = cvt[dtype_x]
|
||||
n_programs = 37
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
pid = tl.program_id(0)
|
||||
old = tl.atomic_add(X, pid)
|
||||
tl.store(Z + pid, old)
|
||||
|
||||
# triton result
|
||||
x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
|
||||
z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
|
||||
kernel[(n_programs, )](x_tri, z_tri)
|
||||
last_sum = torch.max(z_tri) + torch.argmax(z_tri)
|
||||
last_sum = last_sum.to(dtype_x)
|
||||
# torch result
|
||||
range = torch.arange(n_programs, dtype=torch.int32, device=device)
|
||||
x_ref = torch.sum(range).to(dtype_x)
|
||||
triton.testing.assert_allclose(x_ref, x_tri[0])
|
||||
triton.testing.assert_allclose(x_ref, last_sum)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user