[LANGUAGE] Added cos/sin (#132)
This commit is contained in:
committed by
Philippe Tillet
parent
3169e4355c
commit
2824345065
@@ -152,6 +152,8 @@ public:
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
|
@@ -145,6 +145,8 @@ public:
|
||||
value *create_atomic_exch(value *ptr, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
|
@@ -91,6 +91,8 @@ struct dispatch{
|
||||
// math
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
|
@@ -142,6 +142,8 @@ enum value_id_t: unsigned {
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
|
@@ -657,6 +657,28 @@ public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class cos_inst: public builtin_inst {
|
||||
private:
|
||||
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "cos"; }
|
||||
_TRITON_DEFINE_CLONE(cos_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cos_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sin_inst: public builtin_inst {
|
||||
private:
|
||||
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "sin"; }
|
||||
_TRITON_DEFINE_CLONE(sin_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sin_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class log_inst: public builtin_inst {
|
||||
private:
|
||||
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
|
@@ -49,6 +49,8 @@ class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
@@ -122,6 +124,8 @@ public:
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
|
@@ -844,6 +844,30 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `cos`
|
||||
*/
|
||||
void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `sin`
|
||||
*/
|
||||
void generator::visit_sin_inst(ir::sin_inst* x){
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(sin, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `log`
|
||||
*/
|
||||
|
@@ -337,6 +337,14 @@ value *builder::create_exp(value *arg){
|
||||
return insert(exp_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_cos(value *arg){
|
||||
return insert(cos_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_sin(value *arg){
|
||||
return insert(sin_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_log(value *arg){
|
||||
return insert(log_inst::create(arg));
|
||||
}
|
||||
|
@@ -688,6 +688,14 @@ ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_log(x);
|
||||
}
|
||||
|
||||
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_cos(x);
|
||||
}
|
||||
|
||||
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_sin(x);
|
||||
}
|
||||
|
||||
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_sqrt(x);
|
||||
}
|
||||
|
@@ -781,6 +781,27 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction *
|
||||
return new exp_inst(val, name, next);
|
||||
}
|
||||
|
||||
// cos
|
||||
cos_inst::cos_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_COS, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* cos_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new cos_inst(val, name, next);
|
||||
}
|
||||
|
||||
// sin
|
||||
sin_inst::sin_inst(value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(val->get_type(), INST_SIN, 1, name, next) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
instruction* sin_inst::create(value *val, const std::string& name, instruction *next) {
|
||||
return new sin_inst(val, name, next);
|
||||
}
|
||||
|
||||
|
||||
// log
|
||||
|
||||
log_inst::log_inst(value *val, const std::string &name, instruction *next)
|
||||
|
@@ -156,6 +156,8 @@ void init_triton_frontend(py::module &&m) {
|
||||
// math
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
|
@@ -35,7 +35,7 @@ def patch_kernel(template, to_replace):
|
||||
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, device='cuda'):
|
||||
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
@@ -48,8 +48,9 @@ def _test_unary(dtype_x, expr, device='cuda'):
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
|
||||
# inputs
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
if 'log' in expr: x = torch.abs(x) + 0.01
|
||||
# reference result
|
||||
z_ref = eval(expr)
|
||||
z_ref = eval(expr if torch_expr is None else torch_expr)
|
||||
# triton result
|
||||
z_tri = torch.empty_like(z_ref)
|
||||
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
|
||||
@@ -135,6 +136,19 @@ def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
|
||||
def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
_test_unary(dtype_x, expr, device=device)
|
||||
|
||||
# ----------------
|
||||
# test math ops
|
||||
# ----------------
|
||||
# @pytest.mark.paramterize("expr", [
|
||||
# 'exp', 'log', 'cos', 'sin'
|
||||
# ])
|
||||
|
||||
@pytest.mark.parametrize("expr", [
|
||||
'exp', 'log', 'cos', 'sin'
|
||||
])
|
||||
def test_math_op(expr, device='cuda'):
|
||||
_test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test indexing
|
||||
|
@@ -546,6 +546,28 @@ def log(x, builder=None):
|
||||
|
||||
return frontend.log(x, builder)
|
||||
|
||||
@builtin
|
||||
def cos(x, builder=None):
|
||||
"""
|
||||
Computes the element-wise cosine of :code:`x`
|
||||
|
||||
:param x: the input values
|
||||
:type x: Block
|
||||
"""
|
||||
|
||||
return frontend.cos(x, builder)
|
||||
|
||||
@builtin
|
||||
def sin(x, builder=None):
|
||||
"""
|
||||
Computes the element-wise sine of :code:`x`
|
||||
|
||||
:param x: the input values
|
||||
:type x: Block
|
||||
"""
|
||||
|
||||
return frontend.sin(x, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def sqrt(x, builder=None):
|
||||
|
Reference in New Issue
Block a user