[IR] Added IR and Codegen support for atomic_rmw (#120)
This commit is contained in:
committed by
Philippe Tillet
parent
59b0ac672a
commit
0274429429
@@ -154,7 +154,7 @@ public:
|
|||||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||||
void visit_atomic_add_inst(ir::atomic_add_inst*);
|
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||||
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||||
|
@@ -137,7 +137,7 @@ public:
|
|||||||
value *create_get_num_programs(unsigned axis);
|
value *create_get_num_programs(unsigned axis);
|
||||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||||
value *create_atomic_exch(value *ptr, value *val);
|
value *create_atomic_exch(value *ptr, value *val);
|
||||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||||
value *create_exp(value* arg);
|
value *create_exp(value* arg);
|
||||||
value *create_log(value* arg);
|
value *create_log(value* arg);
|
||||||
value *create_dot(value *A, value *B, value *C);
|
value *create_dot(value *A, value *B, value *C);
|
||||||
|
@@ -71,6 +71,11 @@ struct dispatch{
|
|||||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
||||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||||
|
|
||||||
// linear algebra
|
// linear algebra
|
||||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||||
|
@@ -28,6 +28,18 @@ enum binary_op_t: unsigned int{
|
|||||||
Xor
|
Xor
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class atomic_rmw_op_t: unsigned int{
|
||||||
|
And,
|
||||||
|
Or,
|
||||||
|
Xor,
|
||||||
|
Add,
|
||||||
|
Max,
|
||||||
|
Min,
|
||||||
|
UMax,
|
||||||
|
UMin,
|
||||||
|
FAdd,
|
||||||
|
};
|
||||||
|
|
||||||
enum cast_op_t: unsigned int {
|
enum cast_op_t: unsigned int {
|
||||||
Trunc,
|
Trunc,
|
||||||
ZExt,
|
ZExt,
|
||||||
@@ -127,7 +139,7 @@ enum value_id_t: unsigned {
|
|||||||
// atomics
|
// atomics
|
||||||
INST_ATOMIC_CAS,
|
INST_ATOMIC_CAS,
|
||||||
INST_ATOMIC_EXCH,
|
INST_ATOMIC_EXCH,
|
||||||
INST_ATOMIC_ADD,
|
INST_ATOMIC_RMW,
|
||||||
// math
|
// math
|
||||||
INST_EXP,
|
INST_EXP,
|
||||||
INST_LOG,
|
INST_LOG,
|
||||||
|
@@ -608,17 +608,20 @@ public:
|
|||||||
using io_inst::io_inst;
|
using io_inst::io_inst;
|
||||||
};
|
};
|
||||||
|
|
||||||
class atomic_add_inst: public atomic_inst {
|
class atomic_rmw_inst: public atomic_inst {
|
||||||
private:
|
private:
|
||||||
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||||
std::string repr_impl() const { return "atomic_add"; }
|
std::string repr_impl() const { return "atomic_rmw"; }
|
||||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
|
||||||
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
|
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
atomic_rmw_op_t get_op() { return op_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
atomic_rmw_op_t op_;
|
||||||
|
};
|
||||||
|
|
||||||
class atomic_cas_inst: public atomic_inst {
|
class atomic_cas_inst: public atomic_inst {
|
||||||
private:
|
private:
|
||||||
|
@@ -53,9 +53,10 @@ class log_inst;
|
|||||||
|
|
||||||
class get_program_id_inst;
|
class get_program_id_inst;
|
||||||
class get_num_programs_inst;
|
class get_num_programs_inst;
|
||||||
|
class atomic_inst;
|
||||||
class atomic_cas_inst;
|
class atomic_cas_inst;
|
||||||
class atomic_exch_inst;
|
class atomic_exch_inst;
|
||||||
class atomic_add_inst;
|
class atomic_rmw_inst;
|
||||||
class dot_inst;
|
class dot_inst;
|
||||||
class trans_inst;
|
class trans_inst;
|
||||||
class sqrt_inst;
|
class sqrt_inst;
|
||||||
@@ -132,7 +133,7 @@ public:
|
|||||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||||
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
|
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
|
||||||
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
|
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||||
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
||||||
|
@@ -403,6 +403,7 @@ void layouts::run(ir::module &mod) {
|
|||||||
for(const auto& x: values_)
|
for(const auto& x: values_)
|
||||||
create(x.first, x.second);
|
create(x.first, x.second);
|
||||||
|
|
||||||
|
|
||||||
// create temporaries
|
// create temporaries
|
||||||
size_t id = values_.size();
|
size_t id = values_.size();
|
||||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||||
|
@@ -813,28 +813,35 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
|||||||
Module *module = current->getModule();
|
Module *module = current->getModule();
|
||||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
||||||
Value *pred = icmp_eq(tid, i32(0));
|
Value *pred = icmp_eq(tid, i32(0));
|
||||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
// BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
||||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
// BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
||||||
add_barrier();
|
add_barrier();
|
||||||
tgt_->add_memfence(module, *builder_);
|
tgt_->add_memfence(module, *builder_);
|
||||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
Value *atom_ptr;
|
||||||
builder_->SetInsertPoint(tid_0_bb);
|
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
|
||||||
|
atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3));
|
||||||
|
// cond_br(pred, tid_0_bb, tid_0_done_bb);
|
||||||
|
// builder_->SetInsertPoint(tid_0_bb);
|
||||||
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
|
Value *cas_ptr = vals_[cas->get_operand(0)][{}];
|
||||||
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
|
Value *cas_cmp = vals_[cas->get_operand(1)][{}];
|
||||||
Value *cas_val = vals_[cas->get_operand(2)][{}];
|
Value *cas_val = vals_[cas->get_operand(2)][{}];
|
||||||
Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
|
std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;";
|
||||||
old = extract_val(old, std::vector<unsigned>{0});
|
FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false);
|
||||||
Value *atom_ptr;
|
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true);
|
||||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), "");
|
add_barrier();
|
||||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val});
|
||||||
store(old, atom_ptr);
|
add_barrier();
|
||||||
br(tid_0_done_bb);
|
|
||||||
builder_->SetInsertPoint(tid_0_done_bb);
|
std::string asm2_str = "@$0 st.shared.b32 [$1], $2;";
|
||||||
|
FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false);
|
||||||
|
InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true);
|
||||||
|
add_barrier();
|
||||||
|
call(iasm2, {pred, atom_ptr, old});
|
||||||
tgt_->add_memfence(module, *builder_);
|
tgt_->add_memfence(module, *builder_);
|
||||||
add_barrier();
|
add_barrier();
|
||||||
vals_[cas][{}] = load(atom_ptr);
|
vals_[cas][{}] = load(atom_ptr);
|
||||||
|
add_barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `atomic_exch`
|
* \brief Code Generation for `atomic_exch`
|
||||||
*/
|
*/
|
||||||
@@ -860,20 +867,19 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
|||||||
/**
|
/**
|
||||||
* \brief Code Generation for `atomic_add`
|
* \brief Code Generation for `atomic_rmw`
|
||||||
*/
|
*/
|
||||||
//TODO: clean-up
|
void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
ir::value* ptr = atom->get_operand(0);
|
||||||
|
ir::value* val = atom->get_operand(1);
|
||||||
if(add->get_type()->is_block_ty()){
|
ir::value* msk = atom->get_operand(2);
|
||||||
ir::value* ptr = add->get_operand(0);
|
|
||||||
ir::value* val = add->get_operand(1);
|
|
||||||
ir::value* msk = add->get_operand(2);
|
|
||||||
|
|
||||||
// vector size
|
// vector size
|
||||||
int vec = 1;
|
int vec = 1;
|
||||||
|
if(atom->get_type()->is_block_ty()){
|
||||||
int ld = ords_.at(ptr)[0];
|
int ld = ords_.at(ptr)[0];
|
||||||
unsigned alignment = alignment_->get(ptr, ld);
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
||||||
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
|
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
|
||||||
|
}
|
||||||
|
|
||||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||||
auto idx = idxs_[val][i];
|
auto idx = idxs_[val][i];
|
||||||
@@ -900,53 +906,47 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
|||||||
// asm function type
|
// asm function type
|
||||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||||
// asm string
|
// asm string
|
||||||
std::string suffix = vec == 2 ? "x2" : "";
|
std::string s_nbits = std::to_string(nbits);
|
||||||
|
std::string name;
|
||||||
|
std::string s_ty;
|
||||||
|
using tt = ir::atomic_rmw_op_t;
|
||||||
|
switch(atom->get_op()){
|
||||||
|
case tt::Or: name = "or"; s_ty = "b"; break;
|
||||||
|
case tt::And: name = "and"; s_ty = "b"; break;
|
||||||
|
case tt::Xor: name = "xor", s_ty = "b"; break;
|
||||||
|
case tt::Add: name = "add" , s_ty = "s"; break;
|
||||||
|
case tt::Min: name = "min", s_ty = "s"; break;
|
||||||
|
case tt::Max: name = "max", s_ty = "s"; break;
|
||||||
|
case tt::UMin: name = "min", s_ty = "u"; break;
|
||||||
|
case tt::UMax: name = "max", s_ty = "u"; break;
|
||||||
|
case tt::FAdd: name = "add", s_ty = "f"; break;
|
||||||
|
}
|
||||||
|
std::string s_vec = vec == 2 ? "x2" : "";
|
||||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||||
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
|
|
||||||
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + suffix + " $0, [$2" + offset + "], $3;";
|
std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
|
||||||
std::string ty_id = nbits == 32 ? ty_str : (vec == 1 ? "h" : "r");
|
std::string ty_id = nbits*vec == 32 ? "r" : "h";
|
||||||
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
|
std::string constraint = "=" + ty_id + ",b,l," + ty_id;
|
||||||
// create inline asm
|
// create inline asm
|
||||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||||
// call asm
|
// call asm
|
||||||
vals_[add][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
if(atom->get_type()->is_block_ty())
|
||||||
}
|
vals_[atom][idx] = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||||
}
|
|
||||||
else{
|
else{
|
||||||
Value *rmw_ptr = vals_[add->get_operand(0)][{}];
|
Module *mod = builder_->GetInsertBlock()->getModule();
|
||||||
Value *rmw_val = vals_[add->get_operand(1)][{}];
|
tgt_->add_memfence(mod, *builder_);
|
||||||
Value *rmw_msk = vals_[add->get_operand(2)][{}];
|
|
||||||
Type* ty = rmw_val->getType();
|
|
||||||
size_t nbits = ty->getScalarSizeInBits();
|
|
||||||
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
|
||||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
|
||||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
|
||||||
std::string ty_str = add->get_type()->get_scalar_ty()->is_floating_point_ty() ? "f" : "u";
|
|
||||||
std::string asm_str = "@$1 atom.global.gpu.add" + mod + "." + ty_str + std::to_string(nbits) + " $0, [$2], $3;";
|
|
||||||
std::string ty_id = nbits == 32 ? "r" : "h";
|
|
||||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "="+ty_id+",b,l,"+ty_id, true);
|
|
||||||
|
|
||||||
BasicBlock *current = builder_->GetInsertBlock();
|
|
||||||
Module *module = current->getModule();
|
|
||||||
|
|
||||||
Value *tid = tgt_->get_local_id(module, *builder_, 0);
|
|
||||||
Value *pred = icmp_eq(tid, i32(0));
|
|
||||||
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
|
|
||||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
|
|
||||||
tgt_->add_memfence(module, *builder_);
|
|
||||||
add_barrier();
|
add_barrier();
|
||||||
cond_br(pred, tid_0_bb, tid_0_done_bb);
|
Value *tid = tgt_->get_local_id(mod, *builder_, 0);
|
||||||
builder_->SetInsertPoint(tid_0_bb);
|
rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0)));
|
||||||
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
Value *old = call(iasm, (ArrayRef<Value*>{rmw_msk, rmw_ptr, rmw_val}));
|
||||||
Value *atom_ptr;
|
Value *atom_ptr;
|
||||||
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(add)))), "");
|
atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), "");
|
||||||
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3));
|
||||||
store(old, atom_ptr);
|
store(old, atom_ptr);
|
||||||
br(tid_0_done_bb);
|
|
||||||
builder_->SetInsertPoint(tid_0_done_bb);
|
|
||||||
tgt_->add_memfence(module, *builder_);
|
|
||||||
add_barrier();
|
add_barrier();
|
||||||
vals_[add][{}] = load(atom_ptr);
|
vals_[atom][idx] = load(atom_ptr);
|
||||||
|
add_barrier();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -25,8 +25,8 @@ void dce::run(ir::module &mod) {
|
|||||||
case ir::INST_COND_BRANCH:
|
case ir::INST_COND_BRANCH:
|
||||||
case ir::INST_UNMASKED_STORE:
|
case ir::INST_UNMASKED_STORE:
|
||||||
case ir::INST_MASKED_STORE:
|
case ir::INST_MASKED_STORE:
|
||||||
case ir::INST_ATOMIC_ADD:
|
|
||||||
case ir::INST_ATOMIC_CAS:
|
case ir::INST_ATOMIC_CAS:
|
||||||
|
case ir::INST_ATOMIC_RMW:
|
||||||
case ir::INST_ATOMIC_EXCH:
|
case ir::INST_ATOMIC_EXCH:
|
||||||
case ir::INST_BARRIER: {
|
case ir::INST_BARRIER: {
|
||||||
work_list.push_back(i);
|
work_list.push_back(i);
|
||||||
|
@@ -323,8 +323,8 @@ value *builder::create_atomic_exch(value *ptr, value *val){
|
|||||||
return insert(atomic_exch_inst::create(ptr, val));
|
return insert(atomic_exch_inst::create(ptr, val));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_atomic_add(value *ptr, value *val, value *msk){
|
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||||
return insert(atomic_add_inst::create(ptr, val, msk));
|
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_exp(value *arg){
|
value *builder::create_exp(value *arg){
|
||||||
|
@@ -538,7 +538,7 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
|
|||||||
return builder->create_atomic_exch(ptr, val);
|
return builder->create_atomic_exch(ptr, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
|
||||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||||
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
|
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
|
||||||
if(ptr->get_type()->is_block_ty()){
|
if(ptr->get_type()->is_block_ty()){
|
||||||
@@ -550,14 +550,69 @@ ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||||
|
|
||||||
if(!mask){
|
if(!mask){
|
||||||
mask = builder->get_int1(true);
|
mask = builder->get_int1(true);
|
||||||
if(ptr->get_type()->is_block_ty())
|
if(ptr->get_type()->is_block_ty())
|
||||||
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
|
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return builder->create_atomic_add(ptr, val, mask);
|
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||||
|
// direct call to atomic_max for integers
|
||||||
|
if(sca_ty->is_integer_ty())
|
||||||
|
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
|
||||||
|
// for float
|
||||||
|
// return atomic_smax(i_ptr, i_val) if val >= 0
|
||||||
|
// return atomic_umin(i_ptr, i_val) if val < 0
|
||||||
|
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||||
|
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||||
|
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||||
|
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||||
|
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
|
||||||
|
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
|
||||||
|
return where(pos, pos_ret, neg_ret, builder);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||||
|
// direct call to atomic_min for integers
|
||||||
|
if(sca_ty->is_integer_ty())
|
||||||
|
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
|
||||||
|
// for float
|
||||||
|
// return atomic_smin(i_ptr, i_val) if val >= 0
|
||||||
|
// return atomic_umax(i_ptr, i_val) if val < 0
|
||||||
|
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
|
||||||
|
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
|
||||||
|
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
|
||||||
|
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
|
||||||
|
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
|
||||||
|
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
|
||||||
|
return where(pos, pos_ret, neg_ret, builder);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||||
|
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
|
||||||
|
return builder->create_atomic_rmw(op, ptr, val, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||||
|
atom_red_typechecking(ptr, val, mask, builder);
|
||||||
|
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -730,17 +730,20 @@ instruction* get_num_programs_inst::create(context &ctx, unsigned axis, const st
|
|||||||
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
return new get_num_programs_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// atomic_rmw
|
||||||
|
|
||||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
atomic_rmw_inst::atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next)
|
||||||
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
|
: atomic_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_RMW, 3, name, next), op_(op) {
|
||||||
set_operand(0, ptr);
|
set_operand(0, ptr);
|
||||||
set_operand(1, val);
|
set_operand(1, val);
|
||||||
set_operand(2, msk);
|
set_operand(2, msk);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
instruction* atomic_rmw_inst::create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
|
||||||
return new atomic_add_inst(ptr, val, msk, name, next);
|
return new atomic_rmw_inst(op, ptr, val, msk, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// atomic cas
|
// atomic cas
|
||||||
|
|
||||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||||
|
@@ -137,6 +137,11 @@ void init_triton_frontend(py::module &&m) {
|
|||||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||||
|
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
|
||||||
|
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
|
||||||
|
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
|
||||||
|
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
|
||||||
|
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
|
||||||
// linear algebra
|
// linear algebra
|
||||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||||
// indexing
|
// indexing
|
||||||
|
@@ -4,6 +4,7 @@ import triton.language as tl
|
|||||||
import copy
|
import copy
|
||||||
import pytest
|
import pytest
|
||||||
import ast
|
import ast
|
||||||
|
import itertools
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
@@ -238,8 +239,13 @@ def test_tuples():
|
|||||||
# ---------------
|
# ---------------
|
||||||
# test atomics
|
# test atomics
|
||||||
# ---------------
|
# ---------------
|
||||||
@pytest.mark.parametrize("dtype_x", ['int32', 'float16', 'float32'])
|
@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([
|
||||||
def test_atomic_add(dtype_x, device='cuda'):
|
[('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \
|
||||||
|
('max', 'int32', mode), ('max', 'float32', mode),\
|
||||||
|
('min', 'int32', mode), ('min', 'float32', mode),\
|
||||||
|
]
|
||||||
|
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||||
|
def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||||
dtype_x = cvt[dtype_x]
|
dtype_x = cvt[dtype_x]
|
||||||
n_programs = 37
|
n_programs = 37
|
||||||
|
|
||||||
@@ -247,20 +253,39 @@ def test_atomic_add(dtype_x, device='cuda'):
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(X, Z, **meta):
|
def kernel(X, Z, **meta):
|
||||||
pid = tl.program_id(0)
|
pid = tl.program_id(0)
|
||||||
old = tl.atomic_add(X, pid)
|
x = tl.load(X + pid)
|
||||||
tl.store(Z + pid, old)
|
old = GENERATE_TEST_HERE
|
||||||
|
|
||||||
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
|
||||||
|
torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op]
|
||||||
|
max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min
|
||||||
|
min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max
|
||||||
|
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
|
||||||
|
|
||||||
# triton result
|
# triton result
|
||||||
x_tri = torch.zeros((1, ), dtype=dtype_x, device=device)
|
x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device)
|
||||||
z_tri = torch.empty((n_programs, ), dtype=torch.int32, device=device)
|
if mode == 'all_neg':
|
||||||
|
x_tri = -torch.abs(x_tri)
|
||||||
|
if mode == 'all_pos':
|
||||||
|
x_tri = torch.abs(x_tri)
|
||||||
|
if mode == 'min_neg':
|
||||||
|
idx = torch.randint(n_programs, size=(1, )).item()
|
||||||
|
x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1
|
||||||
|
if mode == 'max_pos':
|
||||||
|
idx = torch.randint(n_programs, size=(1, )).item()
|
||||||
|
x_tri[idx] = torch.max(torch.abs(x_tri)) + 1
|
||||||
|
|
||||||
|
z_tri = torch.empty([], dtype=dtype_x, device=device)
|
||||||
|
z_tri.fill_(neutral)
|
||||||
kernel[(n_programs, )](x_tri, z_tri)
|
kernel[(n_programs, )](x_tri, z_tri)
|
||||||
last_sum = torch.max(z_tri) + torch.argmax(z_tri)
|
|
||||||
last_sum = last_sum.to(dtype_x)
|
|
||||||
# torch result
|
# torch result
|
||||||
range = torch.arange(n_programs, dtype=torch.int32, device=device)
|
z_ref = torch_op(x_tri).to(dtype_x)
|
||||||
x_ref = torch.sum(range).to(dtype_x)
|
# compare
|
||||||
triton.testing.assert_allclose(x_ref, x_tri[0])
|
exact = op not in ['add']
|
||||||
triton.testing.assert_allclose(x_ref, last_sum)
|
if exact:
|
||||||
|
assert z_ref.item() == z_tri.item()
|
||||||
|
else:
|
||||||
|
triton.testing.assert_allclose(z_ref, z_tri)
|
||||||
|
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
|
@@ -89,6 +89,8 @@ float16 = dtype(ir.type.get_fp16)
|
|||||||
float32 = dtype(ir.type.get_fp32)
|
float32 = dtype(ir.type.get_fp32)
|
||||||
float64 = dtype(ir.type.get_fp64)
|
float64 = dtype(ir.type.get_fp64)
|
||||||
|
|
||||||
|
pi32_t = pointer_dtype(int32)
|
||||||
|
|
||||||
|
|
||||||
class block:
|
class block:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -464,6 +466,31 @@ def atomic_add(pointer, val, mask=None, builder=None):
|
|||||||
return frontend.atomic_add(pointer, val, mask, builder)
|
return frontend.atomic_add(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def atomic_max(pointer, val, mask=None, builder=None):
|
||||||
|
return frontend.atomic_max(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def atomic_min(pointer, val, mask=None, builder=None):
|
||||||
|
return frontend.atomic_min(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def atomic_and(pointer, val, mask=None, builder=None):
|
||||||
|
return frontend.atomic_and(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def atomic_or(pointer, val, mask=None, builder=None):
|
||||||
|
return frontend.atomic_or(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def atomic_xor(pointer, val, mask=None, builder=None):
|
||||||
|
return frontend.atomic_xor(pointer, val, mask, builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Conditioning
|
# Conditioning
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
@@ -76,7 +76,7 @@ def random(shape, dtype, device):
|
|||||||
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||||
return torch.randint(1, 32, shape, dtype=dtype, device=device)
|
return torch.randint(1, 32, shape, dtype=dtype, device=device)
|
||||||
if dtype in [torch.float16, torch.float32, torch.float64]:
|
if dtype in [torch.float16, torch.float32, torch.float64]:
|
||||||
return torch.randn(shape, dtype=dtype, device=device)
|
return torch.normal(0, 10, shape, dtype=dtype, device=device)
|
||||||
raise RuntimeError(f'Unknown dtype {dtype}')
|
raise RuntimeError(f'Unknown dtype {dtype}')
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user