202 lines
6.5 KiB
Python
202 lines
6.5 KiB
Python
import tempfile
|
|
from inspect import Parameter, Signature
|
|
|
|
import _testcapi
|
|
import pytest
|
|
import torch
|
|
from torch.testing import assert_close
|
|
from tests.libdevice_testutil import system_libdevice_path
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
# Dtype names used in the parametrized test cases, mapped to torch dtypes.
torch_type = dict(
    bool=torch.bool,
    int32=torch.int32,
    float32=torch.float32,
    float64=torch.float64,
)
|
|
|
|
# Name of each ``tl.<expr>`` op under test, mapped to the name of the torch
# function used as its reference. ``None`` marks ops whose reference value
# is computed explicitly inside the test instead of looked up on ``torch``.
torch_ops = dict(
    log="log",
    cos="cos",
    sin="sin",
    sqrt="sqrt",
    abs="abs",
    exp="exp",
    sigmoid="sigmoid",
    umulhi=None,
    cdiv=None,
    fdiv="div",
    minimum="minimum",
    maximum="maximum",
    where="where",
)
|
|
|
|
|
|
def get_tensor(shape, data_type, b_positive=False):
    """Return a random CUDA tensor of dtype ``torch_type[data_type]``.

    Args:
        shape: shape of the tensor to create.
        data_type: key into ``torch_type`` (e.g. ``'int32'``, ``'bool'``,
            ``'float32'``); an unknown key raises ``KeyError``.
        b_positive: if true, the absolute value of the samples is returned.

    Returns:
        For ``int*`` types, integers drawn uniformly from ``[0, 2**31 - 1)``.
        For ``bool``, a uniform mix of ``False`` and ``True``. For any other
        type, samples from a standard normal distribution.
    """
    dtype = torch_type[data_type]
    if data_type.startswith('int'):
        x = torch.randint(2**31 - 1, shape, dtype=dtype, device='cuda')
    elif data_type.startswith('bool'):
        # ``high`` is exclusive: a bound of 1 only ever yields False, which
        # would make ``where`` always select its second value operand.
        x = torch.randint(2, shape, dtype=dtype, device='cuda')
    else:
        x = torch.randn(shape, dtype=dtype, device='cuda')

    if b_positive:
        x = torch.abs(x)

    return x
|
|
|
|
|
|
@pytest.mark.parametrize('expr, output_type, input0_type',
                         [('log', 'float32', 'float32'),
                          ('log', 'float64', 'float64'),
                          ('cos', 'float32', 'float32'),
                          ('cos', 'float64', 'float64'),
                          ('sin', 'float32', 'float32'),
                          ('sin', 'float64', 'float64'),
                          ('sqrt', 'float32', 'float32'),
                          ('sqrt', 'float64', 'float64'),
                          ('abs', 'float32', 'float32'),
                          ('exp', 'float32', 'float32'),
                          ('exp', 'float64', 'float64'),
                          ('sigmoid', 'float32', 'float32'),
                          ])
def test_single_input(expr, output_type, input0_type):
    """Compare a one-argument ``tl.<expr>`` kernel with its torch reference.

    The kernel is linked against the system libdevice and run on a single
    block of 128 elements. Its output is checked with ``assert_close``
    against ``torch.<torch_ops[expr]>`` applied to the same input.
    """
    src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.{expr}(x)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    # The context manager closes (and so deletes) the temporary file even if
    # compilation or the launch fails. It is kept open until after the launch
    # in case triton reads the kernel source lazily.
    with tempfile.NamedTemporaryFile(mode='w', suffix=".py") as fp:
        fp.write(src)
        fp.flush()

        # Stub whose code object is replaced by an empty one that names the
        # temporary file. Presumably this lets triton.jit's source lookup find
        # ``src`` rather than this stub's body (NOTE(review): confirm against
        # triton's JITFunction). Annotations live on the function object, not
        # the code object, so the ``tl.constexpr`` annotation survives the
        # swap.
        def kernel(X, Y, BLOCK: tl.constexpr):
            pass
        kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
        kernel.__signature__ = Signature(parameters=[
            Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in ("X", "Y", "BLOCK")
        ])
        kernel = triton.jit(kernel)

        shape = (128, )
        # log and sqrt are only real-valued for non-negative inputs
        x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
        # triton result
        y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
        kernel[(1,)](
            x, y,
            BLOCK=shape[0],
            extern_libs={"libdevice": system_libdevice_path()},
        )

    # reference result
    y_ref = getattr(torch, torch_ops[expr])(x)
    # compare
    assert_close(y, y_ref)
|
|
|
|
|
|
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type',
                         [('umulhi', 'int32', 'int32', 'int32'),
                          ('cdiv', 'int32', 'int32', 'int32'),
                          ('fdiv', 'float32', 'float32', 'float32'),
                          ('minimum', 'float32', 'float32', 'float32'),
                          ('maximum', 'float32', 'float32', 'float32'),
                          ])
def test_two_input(expr, output_type, input0_type, input1_type):
    """Compare a two-argument ``tl.<expr>`` kernel with a torch reference.

    ``cdiv`` and ``umulhi`` have no direct torch counterpart, so their
    reference values are computed explicitly. Every other op uses
    ``torch.<torch_ops[expr]>``.
    """
    src = f"""
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
    x0 = tl.load(X0 + tl.arange(0, BLOCK))
    x1 = tl.load(X1 + tl.arange(0, BLOCK))
    y = tl.{expr}(x0, x1)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    # The context manager closes (and so deletes) the temporary file even if
    # compilation or the launch fails. It is kept open until after the launch
    # in case triton reads the kernel source lazily.
    with tempfile.NamedTemporaryFile(mode='w', suffix=".py") as fp:
        fp.write(src)
        fp.flush()

        # Stub whose code object is replaced by an empty one that names the
        # temporary file. Presumably this lets triton.jit's source lookup find
        # ``src`` rather than this stub's body (NOTE(review): confirm against
        # triton's JITFunction). The ``tl.constexpr`` annotation is stored on
        # the function object and survives the swap.
        def kernel(X0, X1, Y, BLOCK: tl.constexpr):
            pass
        kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
        kernel.__signature__ = Signature(parameters=[
            Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in ("X0", "X1", "Y", "BLOCK")
        ])
        kernel = triton.jit(kernel)

        shape = (128, )
        x0 = get_tensor(shape, input0_type)
        x1 = get_tensor(shape, input1_type)
        if expr == "cdiv":
            # limit the range of integers so that x0 + x1 - 1 does not
            # overflow int32 (both now lie in [0, 2**30] and their sum stays
            # below 2**31), and keep the divisor strictly positive
            x0 = x0 % 2**30
            x1 = x1 % 2**30 + 1

        # triton result
        y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
        kernel[(1,)](
            x0, x1, y,
            BLOCK=shape[0],
            extern_libs={"libdevice": system_libdevice_path()},
        )

    # reference result
    if expr == "cdiv":
        y_ref = torch.div(x0 + x1 - 1, x1, rounding_mode='trunc')
    elif expr == "umulhi":
        # high 32 bits of the 64-bit product; assumes non-negative inputs,
        # for which the signed and unsigned high words coincide
        y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
    else:
        y_ref = getattr(torch, torch_ops[expr])(x0, x1)
    # compare
    assert_close(y, y_ref)
|
|
|
|
|
|
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type',
                         [('where', "int32", "bool", "int32", "int32"), ])
def test_three_input(expr, output_type, input0_type, input1_type, input2_type):
    """Compare a three-argument ``tl.<expr>`` kernel with its torch reference.

    Each input is generated with its own parametrized dtype, and the result
    is checked against ``torch.<torch_ops[expr]>`` applied to the same
    three tensors.
    """
    src = f"""
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
    x0 = tl.load(X0 + tl.arange(0, BLOCK))
    x1 = tl.load(X1 + tl.arange(0, BLOCK))
    x2 = tl.load(X2 + tl.arange(0, BLOCK))
    y = tl.{expr}(x0, x1, x2)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    # The context manager closes (and so deletes) the temporary file even if
    # compilation or the launch fails. It is kept open until after the launch
    # in case triton reads the kernel source lazily.
    with tempfile.NamedTemporaryFile(mode='w', suffix=".py") as fp:
        fp.write(src)
        fp.flush()

        # Stub whose code object is replaced by an empty one that names the
        # temporary file. Presumably this lets triton.jit's source lookup find
        # ``src`` rather than this stub's body (NOTE(review): confirm against
        # triton's JITFunction). The ``tl.constexpr`` annotation is stored on
        # the function object and survives the swap.
        def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
            pass
        kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
        kernel.__signature__ = Signature(parameters=[
            Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in ("X0", "X1", "X2", "Y", "BLOCK")
        ])
        kernel = triton.jit(kernel)

        shape = (128, )
        x0 = get_tensor(shape, input0_type)
        x1 = get_tensor(shape, input1_type)
        # previously generated with input1_type by mistake
        x2 = get_tensor(shape, input2_type)

        # triton result
        y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
        kernel[(1,)](
            x0, x1, x2, y,
            BLOCK=shape[0],
            extern_libs={"libdevice": system_libdevice_path()},
        )

    # reference result
    y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
    # compare
    assert_close(y, y_ref)
|