From 2824345065a5e26464c1ba4e62cf04f07b50ff3a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 14 Jul 2021 17:16:48 -0700 Subject: [PATCH] [LANGUAGE] Added cos/sin (#132) --- include/triton/codegen/selection/generator.h | 2 ++ include/triton/ir/builder.h | 2 ++ include/triton/ir/dispatch.h | 2 ++ include/triton/ir/enums.h | 2 ++ include/triton/ir/instructions.h | 22 ++++++++++++++++++ include/triton/ir/visitor.h | 4 ++++ lib/codegen/selection/generator.cc | 24 ++++++++++++++++++++ lib/ir/builder.cc | 8 +++++++ lib/ir/dispatch.cc | 8 +++++++ lib/ir/instructions.cc | 21 +++++++++++++++++ python/src/triton.cc | 2 ++ python/test/test_language.py | 18 +++++++++++++-- python/triton/language.py | 22 ++++++++++++++++++ 13 files changed, 135 insertions(+), 2 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 58908ace1..8850b9f80 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -152,6 +152,8 @@ public: void visit_broadcast_inst(ir::broadcast_inst*); void visit_downcast_inst(ir::downcast_inst*); void visit_exp_inst(ir::exp_inst*); + void visit_cos_inst(ir::cos_inst*); + void visit_sin_inst(ir::sin_inst*); void visit_log_inst(ir::log_inst*); void visit_get_program_id_inst(ir::get_program_id_inst*); void visit_get_num_programs_inst(ir::get_num_programs_inst*); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index cd4b1243d..a5aaf0ec9 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -145,6 +145,8 @@ public: value *create_atomic_exch(value *ptr, value *val); value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk); value *create_exp(value* arg); + value *create_cos(value* arg); + value *create_sin(value* arg); value *create_log(value* arg); value *create_dot(value *A, value *B, value *C); value *create_trans(value *A, const std::vector &perm = {}); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index b7fd85207..9b3aed44a 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -91,6 +91,8 @@ struct dispatch{ // math static ir::value *exp(ir::value *x, ir::builder *builder); static ir::value *log(ir::value *x, ir::builder *builder); + static ir::value *cos(ir::value *x, ir::builder *builder); + static ir::value *sin(ir::value *x, ir::builder *builder); static ir::value *sqrt(ir::value *x, ir::builder *builder); // internal (debug/optimization) diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 499767026..dc677cc51 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -142,6 +142,8 @@ enum value_id_t: unsigned { INST_ATOMIC_RMW, // math INST_EXP, + INST_COS, + INST_SIN, INST_LOG, // array arithmetic INST_TRANS, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index cf1f295dc..35759a5b0 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -657,6 +657,28 @@ public: static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); }; +class cos_inst: public builtin_inst { +private: + cos_inst(value *val, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "cos"; } + _TRITON_DEFINE_CLONE(cos_inst) + _TRITON_DEFINE_ACCEPT(cos_inst) + +public: + static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); +}; + +class sin_inst: public builtin_inst { +private: + sin_inst(value *val, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "sin"; } + _TRITON_DEFINE_CLONE(sin_inst) + _TRITON_DEFINE_ACCEPT(sin_inst) + +public: + static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); +}; + class log_inst: public builtin_inst { private: log_inst(value *val, const std::string &name = "", instruction *next = nullptr); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index a42800714..9600cccb1 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -49,6 +49,8 @@ class broadcast_inst; class downcast_inst; class exp_inst; +class cos_inst; +class sin_inst; class log_inst; class get_program_id_inst; @@ -122,6 +124,8 @@ public: virtual void visit_masked_store_inst(masked_store_inst*) = 0; virtual void visit_exp_inst(exp_inst*) = 0; + virtual void visit_cos_inst(cos_inst*) = 0; + virtual void visit_sin_inst(sin_inst*) = 0; virtual void visit_log_inst(log_inst*) = 0; virtual void visit_reshape_inst(reshape_inst*) = 0; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d827928a2..a6d73e030 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -844,6 +844,30 @@ void generator::visit_exp_inst(ir::exp_inst* x){ } } +/** + * \brief Code Generation for `cos` + */ +void generator::visit_cos_inst(ir::cos_inst* x){ + std::vector tys = {f32_ty}; + FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); + InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false); + for(auto idx: idxs_.at(x)){ + vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); + } + } + +/** + * \brief Code Generation for `sin` + */ +void generator::visit_sin_inst(ir::sin_inst* x){ + std::vector tys = {f32_ty}; + FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); + InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false); + for(auto idx: idxs_.at(x)){ + vals_[x][idx] = call(sin, std::vector{vals_[x->get_operand(0)][idx]}); + } + } + /** * \brief Code Generation for `log` */ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 2bd745884..6173a95cc 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -337,6 +337,14 @@ value *builder::create_exp(value *arg){ return insert(exp_inst::create(arg)); } +value *builder::create_cos(value *arg){ + return insert(cos_inst::create(arg)); +} + +value *builder::create_sin(value *arg){ + return insert(sin_inst::create(arg)); +} + value *builder::create_log(value *arg){ return insert(log_inst::create(arg)); } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index cd7223239..63fcb9723 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -688,6 +688,14 @@ ir::value *dispatch::log(ir::value *x, ir::builder *builder) { return builder->create_log(x); } +ir::value *dispatch::cos(ir::value *x, ir::builder *builder) { + return builder->create_cos(x); +} + +ir::value *dispatch::sin(ir::value *x, ir::builder *builder) { + return builder->create_sin(x); +} + ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { return builder->create_sqrt(x); } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index ebfb26c12..1159b22ee 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -781,6 +781,27 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction * return new exp_inst(val, name, next); } +// cos +cos_inst::cos_inst(value *val, const std::string &name, instruction *next) + : builtin_inst(val->get_type(), INST_COS, 1, name, next) { + set_operand(0, val); +} + +instruction* cos_inst::create(value *val, const std::string& name, instruction *next) { + return new cos_inst(val, name, next); +} + +// sin +sin_inst::sin_inst(value *val, const std::string &name, instruction *next) + : builtin_inst(val->get_type(), INST_SIN, 1, name, next) { + set_operand(0, val); +} + +instruction* sin_inst::create(value *val, const std::string& name, instruction *next) { + return new sin_inst(val, name, next); +} + + // log log_inst::log_inst(value *val, const std::string &name, instruction *next) diff --git a/python/src/triton.cc b/python/src/triton.cc index 591e4bf79..6792d7c9c 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -156,6 +156,8 @@ void init_triton_frontend(py::module &&m) { // math m.def("exp", &ir::dispatch::exp, ret::reference); m.def("log", &ir::dispatch::log, ret::reference); + m.def("cos", &ir::dispatch::cos, ret::reference); + m.def("sin", &ir::dispatch::sin, ret::reference); m.def("sqrt", &ir::dispatch::sqrt, ret::reference); // internal (debugging only) m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); diff --git a/python/test/test_language.py b/python/test/test_language.py index 3757cdf22..30ff75bfb 100644 --- a/python/test/test_language.py +++ b/python/test/test_language.py @@ -35,7 +35,7 @@ def patch_kernel(template, to_replace): # generic test functions -def _test_unary(dtype_x, expr, device='cuda'): +def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -48,8 +48,9 @@ def _test_unary(dtype_x, expr, device='cuda'): kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) + if 'log' in expr: x = torch.abs(x) + 0.01 # reference result - z_ref = eval(expr) + z_ref = eval(expr if torch_expr is None else torch_expr) # triton result z_tri = torch.empty_like(z_ref) kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) @@ -135,6 +136,19 @@ def test_compare_op(dtype_x, dtype_y, expr, device='cuda'): def test_unary_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, expr, device=device) +# ---------------- +# test math ops +# ---------------- +# @pytest.mark.paramterize("expr", [ +# 'exp', 'log', 'cos', 'sin' +# ]) + +@pytest.mark.parametrize("expr", [ + 'exp', 'log', 'cos', 'sin' +]) +def test_math_op(expr, device='cuda'): + _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device) + # ---------------- # test indexing diff --git a/python/triton/language.py b/python/triton/language.py index a8e36b9aa..cea27d9ff 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -546,6 +546,28 @@ def log(x, builder=None): return frontend.log(x, builder) +@builtin +def cos(x, builder=None): + """ + Computes the element-wise cosine of :code:`x` + + :param x: the input values + :type x: Block + """ + + return frontend.cos(x, builder) + +@builtin +def sin(x, builder=None): + """ + Computes the element-wise sine of :code:`x` + + :param x: the input values + :type x: Block + """ + + return frontend.sin(x, builder) + @builtin def sqrt(x, builder=None):