[LANGUAGE] Added cos/sin (#132)
This commit is contained in:
committed by
Philippe Tillet
parent
3169e4355c
commit
2824345065
@@ -152,6 +152,8 @@ public:
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
|
@@ -145,6 +145,8 @@ public:
|
||||
value *create_atomic_exch(value *ptr, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
|
@@ -91,6 +91,8 @@ struct dispatch{
|
||||
// math
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
|
@@ -142,6 +142,8 @@ enum value_id_t: unsigned {
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
|
@@ -657,6 +657,28 @@ public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class cos_inst: public builtin_inst {
|
||||
private:
|
||||
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "cos"; }
|
||||
_TRITON_DEFINE_CLONE(cos_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cos_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sin_inst: public builtin_inst {
|
||||
private:
|
||||
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "sin"; }
|
||||
_TRITON_DEFINE_CLONE(sin_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sin_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class log_inst: public builtin_inst {
|
||||
private:
|
||||
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
|
@@ -49,6 +49,8 @@ class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
@@ -122,6 +124,8 @@ public:
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
|
Reference in New Issue
Block a user