[Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
ben-zhang-609
2022-10-28 13:26:29 +08:00
committed by GitHub
parent 3b80801dff
commit 3685194456
9 changed files with 616 additions and 60 deletions

View File

@@ -0,0 +1,189 @@
import tempfile
from inspect import Parameter, Signature
import _testcapi
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
torch_type = {
"bool": torch.bool,
"int32": torch.int32,
"float32": torch.float32,
"float64": torch.float64
}
torch_ops = {
"log": "log",
"cos": "cos",
"sin": "sin",
"sqrt": "sqrt",
"abs": "abs",
"exp": "exp",
"sigmoid": "sigmoid",
"umulhi": None,
"cdiv": None,
"fdiv": "div",
"minimum": "minimum",
"maximum": "maximum",
"where": "where",
}
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
def get_tensor(shape, data_type, b_positive=False):
x = None
if data_type.startswith('int'):
x = torch.randint(2**31 - 1, shape, dtype=torch_type[data_type], device='cuda')
elif data_type.startswith('bool'):
x = torch.randint(1, shape, dtype=torch_type[data_type], device='cuda')
else:
x = torch.randn(shape, dtype=torch_type[data_type], device='cuda')
if b_positive:
x = torch.abs(x)
return x
@pytest.mark.parametrize('expr, output_type, input0_type',
[('log', 'float32', 'float32'),
('log', 'float64', 'float64'),
('cos', 'float32', 'float32'),
('cos', 'float64', 'float64'),
('sin', 'float32', 'float32'),
('sin', 'float64', 'float64'),
('sqrt', 'float32', 'float32'),
('sqrt', 'float64', 'float64'),
('abs', 'float32', 'float32'),
('exp', 'float32', 'float32'),
('sigmoid', 'float32', 'float32'),
])
def test_single_input(expr, output_type, input0_type):
src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.{expr}(x)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
# reference result
y_ref = getattr(torch, torch_ops[expr])(x)
# compare
assert_close(y, y_ref)
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type',
[('umulhi', 'int32', 'int32', 'int32'),
('cdiv', 'int32', 'int32', 'int32'),
('fdiv', 'float32', 'float32', 'float32'),
('minimum', 'float32', 'float32', 'float32'),
('maximum', 'float32', 'float32', 'float32'),
])
def test_two_input(expr, output_type, input0_type, input1_type):
src = f"""
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
x0 = tl.load(X0 + tl.arange(0, BLOCK))
x1 = tl.load(X1 + tl.arange(0, BLOCK))
y = tl.{expr}(x0, x1)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X0", 1))
parameters.append(Parameter("X1", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x0 = get_tensor(shape, input0_type)
x1 = get_tensor(shape, input1_type)
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
# reference result
if expr == "cdiv":
y_ref = (x0 + x1 - 1) // x1
elif expr == "umulhi":
y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
else:
y_ref = getattr(torch, torch_ops[expr])(x0, x1)
# compare
assert_close(y, y_ref)
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type',
[('where', "int32", "bool", "int32", "int32"), ])
def test_three_input(expr, output_type, input0_type, input1_type, input2_type):
src = f"""
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
x0 = tl.load(X0 + tl.arange(0, BLOCK))
x1 = tl.load(X1 + tl.arange(0, BLOCK))
x2 = tl.load(X2 + tl.arange(0, BLOCK))
y = tl.{expr}(x0, x1, x2)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X0", 1))
parameters.append(Parameter("X1", 1))
parameters.append(Parameter("X2", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x0 = get_tensor(shape, input0_type)
x1 = get_tensor(shape, input1_type)
x2 = get_tensor(shape, input1_type)
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
# reference result
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
# compare
assert_close(y, y_ref)

View File

@@ -0,0 +1,178 @@
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_sin_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
x = tl.load(x_ptrs)
y = tl.libdevice.sin(x)
y_ptrs = y_ptr + offset
tl.store(y_ptrs, y)
x_ptr += iter_size
y_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_y = torch.sin(x)
assert_close(y, golden_y, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_fmin_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = tl.libdevice.min(x, y)
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_z = torch.minimum(x, y)
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
w_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
z_ptrs = z_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = tl.load(z_ptrs)
w = tl.libdevice.fma_rn(x, y, z)
w_ptrs = w_ptr + offset
tl.store(w_ptrs, w)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
w_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float64)
y = torch.randn((block_size,), device='cuda', dtype=torch.float64)
z = torch.randn((block_size,), device='cuda', dtype=torch.float64)
w = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_w = x * y + z
assert_close(w, golden_w, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('int32', 'libdevice.ffs', '')])
def test_libdevice(dtype_str, expr, lib_path):
src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.{expr}(x)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
import tempfile
from inspect import Parameter, Signature
import _testcapi
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
torch_type = {
"int32": torch.int32,
"float32": torch.float32,
"float64": torch.float64
}
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = None
if dtype_str == "int32":
x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda")
else:
x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda")
if expr == 'libdevice.ffs':
y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda")
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
# triton result
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
# compare
assert_close(y, y_ref)