[LANGUAGE] Added cos/sin (#132)

This commit is contained in:
Philippe Tillet
2021-07-14 17:16:48 -07:00
committed by Philippe Tillet
parent 3169e4355c
commit 2824345065
13 changed files with 135 additions and 2 deletions

View File

@@ -35,7 +35,7 @@ def patch_kernel(template, to_replace):
# generic test functions
def _test_unary(dtype_x, expr, device='cuda'):
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -48,8 +48,9 @@ def _test_unary(dtype_x, expr, device='cuda'):
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
if 'log' in expr: x = torch.abs(x) + 0.01
# reference result
z_ref = eval(expr)
z_ref = eval(expr if torch_expr is None else torch_expr)
# triton result
z_tri = torch.empty_like(z_ref)
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
@@ -135,6 +136,19 @@ def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
# ----------------
# test math ops
# ----------------
# @pytest.mark.paramterize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device)
# ----------------
# test indexing