[IR] Added IR and Codegen support for atomic_rmw (#120)
Commit 0274429429 (parent 59b0ac672a), committed by Philippe Tillet.
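In user-facing terms, this commit folds the special-purpose atomic_add instruction into a single atomic_rmw instruction parameterized by an op tag, and exposes tl.atomic_max, tl.atomic_min, tl.atomic_and, tl.atomic_or, and tl.atomic_xor alongside the existing tl.atomic_add. A minimal sketch of what the new builtins enable (kernel and tensor names are illustrative, not taken from this diff; launch syntax follows the tests below):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def max_kernel(X, Z, **meta):
        # each program instance folds one element of X into the scalar at Z
        pid = tl.program_id(0)
        x = tl.load(X + pid)
        tl.atomic_max(Z, x)

    x = torch.randn(128, device='cuda')
    z = torch.full((1, ), float('-inf'), device='cuda')
    max_kernel[(128, )](x, z)
    assert z.item() == x.max().item()  # max is exact, unlike float atomic adds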
@@ -154,7 +154,7 @@ public:
   void visit_get_num_programs_inst(ir::get_num_programs_inst*);
   void visit_atomic_cas_inst(ir::atomic_cas_inst*);
   void visit_atomic_exch_inst(ir::atomic_exch_inst*);
-  void visit_atomic_add_inst(ir::atomic_add_inst*);
+  void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
   void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
   void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
   void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
@@ -137,7 +137,7 @@ public:
   value *create_get_num_programs(unsigned axis);
   value *create_atomic_cas(value *ptr, value *cmp, value *val);
   value *create_atomic_exch(value *ptr, value *val);
-  value *create_atomic_add(value *ptr, value *val, value *msk);
+  value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
   value *create_exp(value* arg);
   value *create_log(value* arg);
   value *create_dot(value *A, value *B, value *C);
@@ -71,6 +71,11 @@ struct dispatch{
   static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
   static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
   static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
 
   // linear algebra
   static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
@@ -28,6 +28,18 @@ enum binary_op_t: unsigned int{
   Xor
 };
 
+enum class atomic_rmw_op_t: unsigned int{
+  And,
+  Or,
+  Xor,
+  Add,
+  Max,
+  Min,
+  UMax,
+  UMin,
+  FAdd,
+};
+
 enum cast_op_t: unsigned int {
   Trunc,
   ZExt,
@@ -127,7 +139,7 @@ enum value_id_t: unsigned {
   // atomics
   INST_ATOMIC_CAS,
   INST_ATOMIC_EXCH,
-  INST_ATOMIC_ADD,
+  INST_ATOMIC_RMW,
   // math
   INST_EXP,
   INST_LOG,
@@ -608,17 +608,20 @@ public:
   using io_inst::io_inst;
 };
 
-class atomic_add_inst: public atomic_inst {
+class atomic_rmw_inst: public atomic_inst {
 private:
-  atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
-  std::string repr_impl() const { return "atomic_add"; }
-  _TRITON_DEFINE_CLONE(atomic_add_inst)
-  _TRITON_DEFINE_ACCEPT(atomic_add_inst)
+  atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
+  std::string repr_impl() const { return "atomic_rmw"; }
+  _TRITON_DEFINE_CLONE(atomic_rmw_inst)
+  _TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
 
 public:
-  static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
-};
+  static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
+  atomic_rmw_op_t get_op() { return op_; }
+
+private:
+  atomic_rmw_op_t op_;
+};
 
 class atomic_cas_inst: public atomic_inst {
 private:
@@ -53,9 +53,10 @@ class log_inst;
 
 class get_program_id_inst;
 class get_num_programs_inst;
+class atomic_inst;
 class atomic_cas_inst;
 class atomic_exch_inst;
-class atomic_add_inst;
+class atomic_rmw_inst;
 class dot_inst;
 class trans_inst;
 class sqrt_inst;
@@ -132,7 +133,7 @@ public:
   virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
   virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
   virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
-  virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
+  virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
   virtual void visit_dot_inst(dot_inst*) = 0;
   virtual void visit_trans_inst(trans_inst*) = 0;
   virtual void visit_sqrt_inst(sqrt_inst*) = 0;
@@ -403,6 +403,7 @@ void layouts::run(ir::module &mod) {
   for(const auto& x: values_)
     create(x.first, x.second);
 
+
   // create temporaries
   size_t id = values_.size();
   ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
@@ -813,28 +813,35 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   Module *module = current->getModule();
   Value *tid = tgt_->get_local_id(module, *builder_, 0);
   Value *pred = icmp_eq(tid, i32(0));
-  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
-  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+//  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+//  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
   add_barrier();
   tgt_->add_memfence(module, *builder_);
-  cond_br(pred, tid_0_bb, tid_0_done_bb);
-  builder_->SetInsertPoint(tid_0_bb);
+  Value *atom_ptr;
+  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
+  atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
+//  cond_br(pred, tid_0_bb, tid_0_done_bb);
+//  builder_->SetInsertPoint(tid_0_bb);
   Value *cas_ptr = vals_[cas->get_operand(0)][{}];
   Value *cas_cmp = vals_[cas->get_operand(1)][{}];
   Value *cas_val = vals_[cas->get_operand(2)][{}];
-  Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
-  old = extract_val(old, std::vector<unsigned>{0});
-  Value *atom_ptr;
-  atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
-  atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
-  store(old, atom_ptr);
-  br(tid_0_done_bb);
-  builder_->SetInsertPoint(tid_0_done_bb);
+  std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
+  FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
+  InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
+  add_barrier();
+  Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
+  add_barrier();
+
+  std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
+  FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
+  InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
+  add_barrier();
+  call(iasm2, {pred, atom_ptr, old});
   tgt_->add_memfence(module, *builder_);
   add_barrier();
   vals_[cas][{}] = load(atom_ptr);
+  add_barrier();
 }
 
 /**
  * \brief Code Generation for `atomic_exch`
  */
@@ -860,93 +867,86 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 /**
  * \brief Code Generation for `atomic_add`
 */
 //TODO: clean-up
-void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
+  ir::value* ptr = atom->get_operand(0);
+  ir::value* val = atom->get_operand(1);
+  ir::value* msk = atom->get_operand(2);
 
-  if(add->get_type()->is_block_ty()){
-    ir::value* ptr = add->get_operand(0);
-    ir::value* val = add->get_operand(1);
-    ir::value* msk = add->get_operand(2);
-
-    // vector size
-    int vec = 1;
+  // vector size
+  int vec = 1;
+  if(atom->get_type()->is_block_ty()){
     int ld = ords_.at(ptr)[0];
     unsigned alignment = alignment_->get(ptr, ld);
    vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
+  }
 
-    for(int i = 0; i < idxs_.at(val).size(); i += vec){
-      auto idx = idxs_[val][i];
-      Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
-      for(int ii = 0; ii < vec; ii++)
-        rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
-      Value *rmw_ptr = vals_[ptr][idx];
-      Value *rmw_msk = vals_[msk][idx];
-      if(vec == 1)
-        rmw_val = extract_elt(rmw_val, i32(0));
-      Type* ty = rmw_val->getType();
-      size_t nbits = ty->getScalarSizeInBits();
-      // extract pointer offset
-      std::string offset = "";
-      if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
-        if(gep->getNumIndices() == 1)
-          if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
-            offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
-            rmw_ptr = gep->getPointerOperand();
-          }
-      rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
-      // asm argument type
-      std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
-      // asm function type
-      FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
-      // asm string
-      std::string suffix = vec == 2 ? "x2" : "";
-      std::string mod = nbits == 32 ? "" : ".noftz";
-      std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
-      std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;";
-      std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
-      std::string constraint = "=" + ty_id + ",b,l," + ty_id;
-      // create inline asm
-      InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
-      // call asm
-      vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
-    }
-  }
-  else{
-    Value *rmw_ptr = vals_[add->get_operand(0)][{}];
-    Value *rmw_val = vals_[add->get_operand(1)][{}];
-    Value *rmw_msk = vals_[add->get_operand(2)][{}];
-
+  for(int i = 0; i < idxs_.at(val).size(); i += vec){
+    auto idx = idxs_[val][i];
+    Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
+    for(int ii = 0; ii < vec; ii++)
+      rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
+    Value *rmw_ptr = vals_[ptr][idx];
+    Value *rmw_msk = vals_[msk][idx];
+    if(vec == 1)
+      rmw_val = extract_elt(rmw_val, i32(0));
+    Type* ty = rmw_val->getType();
+    size_t nbits = ty->getScalarSizeInBits();
+    // extract pointer offset
+    std::string offset = "";
+    if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
+      if(gep->getNumIndices() == 1)
+        if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
+          offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
+          rmw_ptr = gep->getPointerOperand();
+        }
+    rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1));
+    // asm argument type
+    std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
+    // asm function type
+    FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
+    // asm string
+    std::string s_nbits = std::to_string(nbits);
+    std::string name;
+    std::string s_ty;
+    using tt = ir::atomic_rmw_op_t;
+    switch(atom->get_op()){
+      case tt::Or: name = "or"; s_ty = "b"; break;
+      case tt::And: name = "and"; s_ty = "b"; break;
+      case tt::Xor: name = "xor"; s_ty = "b"; break;
+      case tt::Add: name = "add"; s_ty = "s"; break;
+      case tt::Min: name = "min"; s_ty = "s"; break;
+      case tt::Max: name = "max"; s_ty = "s"; break;
+      case tt::UMin: name = "min"; s_ty = "u"; break;
+      case tt::UMax: name = "max"; s_ty = "u"; break;
+      case tt::FAdd: name = "add"; s_ty = "f"; break;
+    }
+    std::string s_vec = vec == 2 ? "x2" : "";
+    std::string mod = nbits == 32 ? "" : ".noftz";
-    std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
-    std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;";
-    std::string ty_id = nbits == 32 ? "r" : "h";
-    InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true);
-
-    BasicBlock *current = builder_->GetInsertBlock();
-    Module *module = current->getModule();
-
-    Value *tid = tgt_->get_local_id(module, *builder_, 0);
-    Value *pred = icmp_eq(tid, i32(0));
-    BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
-    BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
-    tgt_->add_memfence(module, *builder_);
-    add_barrier();
-    cond_br(pred, tid_0_bb, tid_0_done_bb);
-    builder_->SetInsertPoint(tid_0_bb);
-    Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
-    Value *atom_ptr;
-    atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), "");
-    atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
-    store(old, atom_ptr);
-    br(tid_0_done_bb);
-    builder_->SetInsertPoint(tid_0_done_bb);
-    tgt_->add_memfence(module, *builder_);
-    add_barrier();
-    vals_[add][{}] = load(atom_ptr);
+    std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
+    std::string ty_id = nbits*vec == 32 ? "r" : "h";
+    std::string constraint = "=" + ty_id + ",b,l," + ty_id;
+    // create inline asm
+    InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
+    // call asm
+    if(atom->get_type()->is_block_ty())
+      vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
+    else{
+      Module *mod = builder_->GetInsertBlock()->getModule();
+      tgt_->add_memfence(mod, *builder_);
+      add_barrier();
+      Value *tid = tgt_->get_local_id(mod, *builder_, 0);
+      rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0)));
+      Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
+      Value *atom_ptr;
+      atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "");
+      atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
+      store(old, atom_ptr);
+      add_barrier();
+      vals_[atom][idx] = load(atom_ptr);
+      add_barrier();
+    }
   }
 }
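For reference, the switch above reduces to the following op → (PTX instruction name, type prefix) table; the generated instruction has the shape atom.global.gpu.<name>[.noftz].<prefix><nbits>[x2]. The table below merely restates the switch in Python for readability and is not part of the commit:

    # mirrors the switch in generator::visit_atomic_rmw_inst
    PTX_RMW = {
        'And':  ('and', 'b'),   # bitwise ops use the untyped .b<nbits> form
        'Or':   ('or',  'b'),
        'Xor':  ('xor', 'b'),
        'Add':  ('add', 's'),   # signed integer
        'Max':  ('max', 's'),
        'Min':  ('min', 's'),
        'UMax': ('max', 'u'),   # unsigned integer
        'UMin': ('min', 'u'),
        'FAdd': ('add', 'f'),   # floating point; .noftz is appended when nbits != 32
    }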
@@ -25,8 +25,8 @@ void dce::run(ir::module &mod) {
       case ir::INST_COND_BRANCH:
       case ir::INST_UNMASKED_STORE:
       case ir::INST_MASKED_STORE:
-      case ir::INST_ATOMIC_ADD:
       case ir::INST_ATOMIC_CAS:
+      case ir::INST_ATOMIC_RMW:
       case ir::INST_ATOMIC_EXCH:
       case ir::INST_BARRIER: {
         work_list.push_back(i);
@@ -323,8 +323,8 @@ value *builder::create_atomic_exch(value *ptr, value *val){
   return insert(atomic_exch_inst::create(ptr, val));
 }
 
-value *builder::create_atomic_add(value *ptr, value *val, value *msk){
-  return insert(atomic_add_inst::create(ptr, val, msk));
+value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
+  return insert(atomic_rmw_inst::create(op, ptr, val, msk));
 }
 
 value *builder::create_exp(value *arg){
@@ -538,7 +538,7 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder){
   return builder->create_atomic_exch(ptr, val);
 }
 
-ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
   if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
     throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
   if(ptr->get_type()->is_block_ty()){
@@ -550,14 +550,69 @@ ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
     }
   }
   val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
 
   if(!mask){
     mask = builder->get_int1(true);
     if(ptr->get_type()->is_block_ty())
       mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
   }
+}
 
-  return builder->create_atomic_add(ptr, val, mask);
+ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_max for integers
+  if(sca_ty->is_integer_ty())
+    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
+  // for float:
+  //   return atomic_smax(i_ptr, i_val) if val >= 0
+  //   return atomic_umin(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
+
+ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_min for integers
+  if(sca_ty->is_integer_ty())
+    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
+  // for float:
+  //   return atomic_smin(i_ptr, i_val) if val >= 0
+  //   return atomic_umax(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
+
+ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
+  return builder->create_atomic_rmw(op, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
 }
 
 //===----------------------------------------------------------------------===//
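The float paths in atomic_max and atomic_min above rely on a standard IEEE-754 property: reinterpreted as 32-bit integers, non-negative floats compare in the same order as their values, while negative floats compare in reverse order under an unsigned comparison. Signed Max over the lanes where val >= 0 and UMin over the lanes where val < 0 therefore combine into a float max, and symmetrically for min. A small self-check of the property (illustration only, not from the commit):

    import struct

    def f32_bits(x):
        # reinterpret a float32 as its raw 32-bit pattern
        return struct.unpack('<I', struct.pack('<f', x))[0]

    # non-negative floats: integer order matches float order
    assert f32_bits(0.0) < f32_bits(0.25) < f32_bits(1.5)
    # negative floats: unsigned order is reversed, so the larger float (-0.25)
    # has the smaller bit pattern -- hence UMin yields the float maximum there
    assert f32_bits(-0.25) < f32_bits(-1.5)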
@@ -730,17 +730,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
   return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
 }
 
 // atomic_rmw
 
-atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
-  : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
+atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
+  : atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
   set_operand(0, ptr);
   set_operand(1, val);
   set_operand(2, msk);
 }
 
-instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
-  return new atomic_add_inst(ptr, val, msk, name, next);
+instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
+  return new atomic_rmw_inst(op, ptr, val, msk, name, next);
 }
 
 
 // atomic cas
 
 atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
@@ -137,6 +137,11 @@ void init_triton_frontend(py::module &&m) {
   m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
   m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
   m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
+  m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
+  m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
+  m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
+  m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
+  m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
   // linear algebra
   m.def("dot", &ir::dispatch::dot, ret::reference);
   // indexing
@@ -4,6 +4,7 @@ import triton.language as tl
 import copy
 import pytest
 import ast
+import itertools
 
 torch.manual_seed(0)
 
@@ -238,8 +239,13 @@ def test_tuples():
 # ---------------
 # test atomics
 # ---------------
-@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
-def test_atomic_add(dtype_x, device='cuda'):
+@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
+    [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
+     ('max', 'int32', mode), ('max', 'float32', mode), \
+     ('min', 'int32', mode), ('min', 'float32', mode), \
+    ]
+    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
+def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     dtype_x = cvt[dtype_x]
     n_programs = 37
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
     @triton.jit
     def kernel(X, Z, **meta):
         pid = tl.program_id(0)
-        old = tl.atomic_add(X, pid)
-        tl.store(Z + pid, old)
+        x = tl.load(X + pid)
+        old = GENERATE_TEST_HERE
 
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
+    torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
+    max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
+    min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
+    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
+
     # triton result
-    x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
-    z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
+    x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
+    if mode == 'all_neg':
+        x_tri = -torch.abs(x_tri)
+    if mode == 'all_pos':
+        x_tri = torch.abs(x_tri)
+    if mode == 'min_neg':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
+    if mode == 'max_pos':
+        idx = torch.randint(n_programs, size=(1, )).item()
+        x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
+
+    z_tri = torch.empty([], dtype=dtype_x, device=device)
+    z_tri.fill_(neutral)
     kernel[(n_programs, )](x_tri, z_tri)
-    last_sum = torch.max(z_tri) + torch.argmax(z_tri)
-    last_sum = last_sum.to(dtype_x)
     # torch result
-    range = torch.arange(n_programs, dtype=torch.int32, device=device)
-    x_ref = torch.sum(range).to(dtype_x)
-    triton.testing.assert_allclose(x_ref, x_tri[0])
-    triton.testing.assert_allclose(x_ref, last_sum)
+    z_ref = torch_op(x_tri).to(dtype_x)
+    # compare
+    exact = op not in ['add']
+    if exact:
+        assert z_ref.item() == z_tri.item()
+    else:
+        triton.testing.assert_allclose(z_ref, z_tri)
 
 
 # ---------------
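The rewritten test initializes z to each op's neutral element, folds every program's value into it via the selected atomic, and runs under four sign regimes; 'max' and 'min' are checked exactly, while 'add' keeps a tolerance since floating-point sums reorder across the 37 programs. The parametrization expands to 28 cases; a quick way to see the expansion, mirroring the itertools expression above:

    import itertools

    params = list(itertools.chain.from_iterable(
        [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode),
         ('max', 'int32', mode), ('max', 'float32', mode),
         ('min', 'int32', mode), ('min', 'float32', mode)]
        for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']))
    assert len(params) == 28  # 7 (op, dtype) pairs x 4 modes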
@@ -89,6 +89,8 @@ float16 = dtype(ir.type.get_fp16)
 float32 = dtype(ir.type.get_fp32)
 float64 = dtype(ir.type.get_fp64)
 
+pi32_t = pointer_dtype(int32)
+
 
 class block:
     @staticmethod
@@ -464,6 +466,31 @@ def atomic_add(pointer, val, mask=None, builder=None):
     return frontend.atomic_add(pointer, val, mask, builder)
 
 
+@builtin
+def atomic_max(pointer, val, mask=None, builder=None):
+    return frontend.atomic_max(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_min(pointer, val, mask=None, builder=None):
+    return frontend.atomic_min(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_and(pointer, val, mask=None, builder=None):
+    return frontend.atomic_and(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_or(pointer, val, mask=None, builder=None):
+    return frontend.atomic_or(pointer, val, mask, builder)
+
+
+@builtin
+def atomic_xor(pointer, val, mask=None, builder=None):
+    return frontend.atomic_xor(pointer, val, mask, builder)
+
+
 # -----------------------
 # Conditioning
 # -----------------------
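All five new builtins share atomic_add's (pointer, val, mask=None) signature, so a mask can gate which program instances participate. A hypothetical sketch (kernel and argument names invented for illustration, not part of the commit):

    @triton.jit
    def or_reduce_kernel(X, Z, N, **meta):
        pid = tl.program_id(0)
        mask = pid < N                  # programs past N contribute nothing
        x = tl.load(X + pid, mask=mask, other=0)
        tl.atomic_or(Z, x, mask=mask)   # bitwise-OR reduction into *Z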
@@ -76,7 +76,7 @@ def random(shape, dtype, device):
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
         return torch.randint(1, 32, shape, dtype=dtype, device=device)
     if dtype in [torch.float16, torch.float32, torch.float64]:
-        return torch.randn(shape, dtype=dtype, device=device)
+        return torch.normal(0, 10, shape, dtype=dtype, device=device)
     raise RuntimeError(f'Unknown dtype {dtype}')
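The testing change swaps torch.randn for torch.normal(0, 10, ...), presumably to spread float samples over a wider range of magnitudes so the atomic max/min test modes exercise clearly separated values. A quick comparison (illustration only, not from the commit):

    import torch
    torch.manual_seed(0)
    a = torch.randn((37, ))          # std 1: values cluster roughly in [-3, 3]
    b = torch.normal(0, 10, (37, ))  # std 10: much wider spread of magnitudes
    print(a.abs().max().item(), b.abs().max().item())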