[LANGUAGE] Added cos/sin (#132)
This commit is contained in:
committed by
Philippe Tillet
parent
3169e4355c
commit
2824345065
@@ -156,6 +156,8 @@ void init_triton_frontend(py::module &&m) {
|
||||
// math
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
|
@@ -35,7 +35,7 @@ def patch_kernel(template, to_replace):
|
||||
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, device='cuda'):
|
||||
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
@@ -48,8 +48,9 @@ def _test_unary(dtype_x, expr, device='cuda'):
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
|
||||
# inputs
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
if 'log' in expr: x = torch.abs(x) + 0.01
|
||||
# reference result
|
||||
z_ref = eval(expr)
|
||||
z_ref = eval(expr if torch_expr is None else torch_expr)
|
||||
# triton result
|
||||
z_tri = torch.empty_like(z_ref)
|
||||
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
|
||||
@@ -135,6 +136,19 @@ def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
|
||||
def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
_test_unary(dtype_x, expr, device=device)
|
||||
|
||||
# ----------------
|
||||
# test math ops
|
||||
# ----------------
|
||||
# @pytest.mark.paramterize("expr", [
|
||||
# 'exp', 'log', 'cos', 'sin'
|
||||
# ])
|
||||
|
||||
@pytest.mark.parametrize("expr", [
|
||||
'exp', 'log', 'cos', 'sin'
|
||||
])
|
||||
def test_math_op(expr, device='cuda'):
|
||||
_test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test indexing
|
||||
|
@@ -546,6 +546,28 @@ def log(x, builder=None):
|
||||
|
||||
return frontend.log(x, builder)
|
||||
|
||||
@builtin
|
||||
def cos(x, builder=None):
|
||||
"""
|
||||
Computes the element-wise cosine of :code:`x`
|
||||
|
||||
:param x: the input values
|
||||
:type x: Block
|
||||
"""
|
||||
|
||||
return frontend.cos(x, builder)
|
||||
|
||||
@builtin
|
||||
def sin(x, builder=None):
|
||||
"""
|
||||
Computes the element-wise sine of :code:`x`
|
||||
|
||||
:param x: the input values
|
||||
:type x: Block
|
||||
"""
|
||||
|
||||
return frontend.sin(x, builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def sqrt(x, builder=None):
|
||||
|
Reference in New Issue
Block a user