[IR] Added IR and Codegen support for atomic_rmw (#120)

This commit is contained in:
Philippe Tillet
2021-05-25 18:31:48 -04:00
committed by Philippe Tillet
parent 59b0ac672a
commit 0274429429
17 changed files with 261 additions and 124 deletions

View File

@@ -154,7 +154,7 @@ public:
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);

View File

@@ -137,7 +137,7 @@ public:
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_exch(value *ptr, value *val);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);

View File

@@ -71,6 +71,11 @@ struct dispatch{
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);

View File

@@ -28,6 +28,18 @@ enum binary_op_t: unsigned int{
Xor
};
enum class atomic_rmw_op_t: unsigned int{
And,
Or,
Xor,
Add,
Max,
Min,
UMax,
UMin,
FAdd,
};
enum cast_op_t: unsigned int {
Trunc,
ZExt,
@@ -127,7 +139,7 @@ enum value_id_t: unsigned {
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_ADD,
INST_ATOMIC_RMW,
// math
INST_EXP,
INST_LOG,

View File

@@ -608,17 +608,20 @@ public:
using io_inst::io_inst;
};
class atomic_add_inst: public atomic_inst {
class atomic_rmw_inst: public atomic_inst {
private:
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_add"; }
_TRITON_DEFINE_CLONE(atomic_add_inst)
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_rmw"; }
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
public:
static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
};
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
atomic_rmw_op_t get_op() { return op_; }
private:
atomic_rmw_op_t op_;
};
class atomic_cas_inst: public atomic_inst {
private:

View File

@@ -53,9 +53,10 @@ class log_inst;
class get_program_id_inst;
class get_num_programs_inst;
class atomic_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_add_inst;
class atomic_rmw_inst;
class dot_inst;
class trans_inst;
class sqrt_inst;
@@ -132,7 +133,7 @@ public:
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
virtual void visit_dot_inst(dot_inst*) = 0;
virtual void visit_trans_inst(trans_inst*) = 0;
virtual void visit_sqrt_inst(sqrt_inst*) = 0;