[IR] Added IR and Codegen support for atomic_rmw (#120)
This commit is contained in:
committed by
Philippe Tillet
parent
59b0ac672a
commit
0274429429
@@ -154,7 +154,7 @@ public:
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||
void visit_atomic_add_inst(ir::atomic_add_inst*);
|
||||
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
|
@@ -137,7 +137,7 @@ public:
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_exch(value *ptr, value *val);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C);
|
||||
|
@@ -71,6 +71,11 @@ struct dispatch{
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
|
@@ -28,6 +28,18 @@ enum binary_op_t: unsigned int{
|
||||
Xor
|
||||
};
|
||||
|
||||
enum class atomic_rmw_op_t: unsigned int{
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Add,
|
||||
Max,
|
||||
Min,
|
||||
UMax,
|
||||
UMin,
|
||||
FAdd,
|
||||
};
|
||||
|
||||
enum cast_op_t: unsigned int {
|
||||
Trunc,
|
||||
ZExt,
|
||||
@@ -127,7 +139,7 @@ enum value_id_t: unsigned {
|
||||
// atomics
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_ADD,
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_EXP,
|
||||
INST_LOG,
|
||||
|
@@ -608,17 +608,20 @@ public:
|
||||
using io_inst::io_inst;
|
||||
};
|
||||
|
||||
class atomic_add_inst: public atomic_inst {
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_add"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
|
||||
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_rmw"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
atomic_rmw_op_t get_op() { return op_; }
|
||||
|
||||
private:
|
||||
atomic_rmw_op_t op_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public atomic_inst {
|
||||
private:
|
||||
|
@@ -53,9 +53,10 @@ class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_exch_inst;
|
||||
class atomic_add_inst;
|
||||
class atomic_rmw_inst;
|
||||
class dot_inst;
|
||||
class trans_inst;
|
||||
class sqrt_inst;
|
||||
@@ -132,7 +133,7 @@ public:
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
|
||||
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
|
||||
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
||||
|
Reference in New Issue
Block a user