[LANGUAGE] Added cos/sin (#132)

This commit is contained in:
Philippe Tillet
2021-07-14 17:16:48 -07:00
committed by Philippe Tillet
parent 3169e4355c
commit 2824345065
13 changed files with 135 additions and 2 deletions

View File

@@ -152,6 +152,8 @@ public:
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);

View File

@@ -145,6 +145,8 @@ public:
value *create_atomic_exch(value *ptr, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_trans(value *A, const std::vector<int> &perm = {});

View File

@@ -91,6 +91,8 @@ struct dispatch{
// math
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)

View File

@@ -142,6 +142,8 @@ enum value_id_t: unsigned {
INST_ATOMIC_RMW,
// math
INST_EXP,
INST_COS,
INST_SIN,
INST_LOG,
// array arithmetic
INST_TRANS,

View File

@@ -657,6 +657,28 @@ public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class cos_inst: public builtin_inst {
private:
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "cos"; }
_TRITON_DEFINE_CLONE(cos_inst)
_TRITON_DEFINE_ACCEPT(cos_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class sin_inst: public builtin_inst {
private:
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "sin"; }
_TRITON_DEFINE_CLONE(sin_inst)
_TRITON_DEFINE_ACCEPT(sin_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class log_inst: public builtin_inst {
private:
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);

View File

@@ -49,6 +49,8 @@ class broadcast_inst;
class downcast_inst;
class exp_inst;
class cos_inst;
class sin_inst;
class log_inst;
class get_program_id_inst;
@@ -122,6 +124,8 @@ public:
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
virtual void visit_exp_inst(exp_inst*) = 0;
virtual void visit_cos_inst(cos_inst*) = 0;
virtual void visit_sin_inst(sin_inst*) = 0;
virtual void visit_log_inst(log_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0;