[CODEGEN] Fixed atomic_add issue (#112)

* [CODEGEN] Fixed atomic_add issue

* [CODEGEN] Fixed liveness analysis bug for instructions that are not
DCE'd but have no users (e.g., atomic_cas)
This commit is contained in:
Philippe Tillet
2021-05-19 21:40:41 -04:00
committed by Philippe Tillet
parent 325ee38581
commit f81012a8cf
6 changed files with 77 additions and 34 deletions

View File

@@ -452,16 +452,6 @@ public:
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
};
// Atomic add instruction, declared here as a direct subclass of io_inst.
// It is built from three values: a pointer, a value and a mask.
// The constructor is private; instances are obtained through create().
class atomic_add_inst: public io_inst {
private:
// Private constructor; `name` defaults to "" and `next` to nullptr.
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
// Returns the mnemonic string for this instruction: "atomic_add".
std::string repr_impl() const { return "atomic_add"; }
// NOTE(review): the two macros below presumably generate the clone/accept
// members for this class; their expansion is not visible here.
_TRITON_DEFINE_CLONE(atomic_add_inst)
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
public:
// Factory taking the same arguments as the private constructor; returns the
// new object as an instruction*.
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
};
// store
@@ -612,7 +602,25 @@ private:
unsigned axis_;
};
class atomic_cas_inst: public builtin_inst {
// Common base class for the atomic instructions (atomic_add_inst,
// atomic_cas_inst and atomic_exch_inst derive from it). It adds no members of
// its own and simply inherits io_inst's constructors, giving other passes a
// single type to dynamic_cast against when they need to recognise any atomic.
class atomic_inst: public io_inst {
public:
// Inherit every io_inst constructor unchanged.
using io_inst::io_inst;
};
// Atomic add instruction, a subclass of atomic_inst.
// It is built from three values: a pointer, a value and a mask.
// The constructor is private; instances are obtained through create().
class atomic_add_inst: public atomic_inst {
private:
// Private constructor; `name` defaults to "" and `next` to nullptr.
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
// Returns the mnemonic string for this instruction: "atomic_add".
std::string repr_impl() const { return "atomic_add"; }
// NOTE(review): the two macros below presumably generate the clone/accept
// members for this class; their expansion is not visible here.
_TRITON_DEFINE_CLONE(atomic_add_inst)
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
public:
// Factory taking the same arguments as the private constructor; returns the
// new object as an instruction*.
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
};
class atomic_cas_inst: public atomic_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
@@ -623,7 +631,7 @@ public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_exch_inst: public builtin_inst {
class atomic_exch_inst: public atomic_inst {
private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; }

View File

@@ -436,7 +436,7 @@ void layouts::run(ir::module &mod) {
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;

View File

@@ -45,6 +45,8 @@ void liveness::run(ir::module &mod) {
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
}

View File

@@ -902,13 +902,14 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
// asm string
std::string suffix = vec == 2 ? "x2" : "";
std::string mod = nbits == 32 ? "" : ".noftz";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
std::string ty_id = nbits == 32 ? "f" : (vec == 1 ? "h" : "r");
std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;";
std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
// call asm
call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
}
}
else{
@@ -920,8 +921,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
std::string mod = nbits == 32 ? "" : ".noftz";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2], $3;";
std::string ty_id = nbits == 32 ? "f" : "h";
std::string ty_str = ty->isFloatingPointTy() ? "f" : "u";
std::string asm_str = "@$0 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $1, [$2], $3;";
std::string ty_id = nbits == 32 ? "r" : "h";
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "b,="+ty_id+",l,"+ty_id, true);
BasicBlock *current = builder_->GetInsertBlock();
@@ -935,10 +937,16 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
add_barrier();
cond_br(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
Value *atom_ptr;
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), "");
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
store(old, atom_ptr);
br(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
add_barrier();
vals_[add][{}] = load(atom_ptr);
}
}

View File

@@ -484,19 +484,6 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
return new masked_load_async_inst(ptr, mask, false_value, name, next);
}
// atomic add
// Constructs an atomic add with three operands. The instruction's type is the
// element type that `ptr` points to.
// Operand slots: 0 = ptr, 1 = val, 2 = msk.
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
  : io_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
  // Install the operands in slot order (0, 1, 2).
  value *operands[] = {ptr, val, msk};
  for(int slot = 0; slot < 3; ++slot)
    set_operand(slot, operands[slot]);
}
// Factory for atomic_add_inst: forwards all arguments to the private
// constructor and returns the heap-allocated instruction.
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
  atomic_add_inst *created = new atomic_add_inst(ptr, val, msk, name, next);
  return created;
}
// store
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
@@ -744,10 +731,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
}
// Constructs an atomic add with three operands, initialising the atomic_inst
// base. The instruction's type is the element type that `ptr` points to.
// Operand slots: 0 = ptr, 1 = val, 2 = msk.
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
  : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
  // Install the operands in slot order (0, 1, 2).
  value *operands[] = {ptr, val, msk};
  for(int slot = 0; slot < 3; ++slot)
    set_operand(slot, operands[slot]);
}
// Factory for atomic_add_inst: forwards all arguments to the private
// constructor and returns the heap-allocated instruction.
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
  atomic_add_inst *created = new atomic_add_inst(ptr, val, msk, name, next);
  return created;
}
// atomic cas
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
set_operand(0, ptr);
set_operand(1, cmp);
set_operand(2, val);
@@ -760,7 +757,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
// atomic exch
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}

View File

@@ -189,6 +189,34 @@ def test_index1d(expr, device='cuda'):
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
def test_atomic_add(dtype_x, device='cuda'):
    """Check `tl.atomic_add` for each parametrized dtype.

    Launches `n_programs` programs; each atomically adds its program id to a
    single accumulator and stores the value returned by the atomic. The test
    checks both the final accumulator and a sum reconstructed from the
    returned values against the reference sum of all program ids.
    """
    dtype_x = cvt[dtype_x]
    n_programs = 37

    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        pid = tl.program_id(0)
        old = tl.atomic_add(X, pid)
        tl.store(Z + pid, old)

    # triton result
    x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
    z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
    kernel[(n_programs, )](x_tri, z_tri)
    # Reconstruct the final sum from the returned values: the largest value
    # observed by any program plus the id of the program that observed it.
    last_sum = torch.max(z_tri) + torch.argmax(z_tri)
    last_sum = last_sum.to(dtype_x)
    # torch result
    # Named `pids` rather than `range` so the Python builtin is not shadowed.
    pids = torch.arange(n_programs, dtype=torch.int32, device=device)
    x_ref = torch.sum(pids).to(dtype_x)
    triton.testing.assert_allclose(x_ref, x_tri[0])
    triton.testing.assert_allclose(x_ref, last_sum)
# ---------------
# test load
# ---------------