[Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
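This commit adds elementwise math ops (log, cos, sin, sqrt, abs, exp, sigmoid, umulhi, cdiv, fdiv, minimum, maximum, where) and tl.libdevice wrappers to the Triton-MLIR backend, and threads external bitcode libraries (libdevice) through compilation. A minimal sketch of the user-facing pattern that the new tests below exercise; the kernel name and libdevice path are illustrative values taken from the tests, not a fixed API:

import torch
import triton
import triton.language as tl

@triton.jit
def sin_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.store(Y + offs, tl.sin(x))  # new elementwise op, lowered through libdevice

x = torch.randn(128, device='cuda')
y = torch.empty_like(x)
# extern_libs points the compiler at an external bitcode library (see add_external_libs below)
sin_kernel[(1,)](x, y, BLOCK=128,
                 extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})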
@@ -1335,6 +1335,12 @@ void init_triton_translation(py::module &m) {
          py::bytes bytes(cubin);
          return bytes;
        });

  m.def("add_external_libs",
        [](mlir::ModuleOp &op, const std::vector<std::string> &names,
           const std::vector<std::string> &paths) {
          ::mlir::triton::addExternalLibs(op, names, paths);
        });
}

void init_triton(py::module &m) {
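The new add_external_libs binding is driven from the Python side (see the add_external_libs helper in the compiler hunks further down): it receives parallel lists of library names and bitcode paths. A sketch of the call shape, with the MLIR module handle (module) assumed to come from the surrounding compile pipeline:

# Mirrors the Python helper shown later in this diff.
libs = {"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}
_triton.add_external_libs(module, list(libs.keys()), list(libs.values()))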
189  python/tests/test_elementwise.py  Normal file
@@ -0,0 +1,189 @@
import tempfile
from inspect import Parameter, Signature

import _testcapi
import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    "int32": torch.int32,
    "float32": torch.float32,
    "float64": torch.float64
}

torch_ops = {
    "log": "log",
    "cos": "cos",
    "sin": "sin",
    "sqrt": "sqrt",
    "abs": "abs",
    "exp": "exp",
    "sigmoid": "sigmoid",
    "umulhi": None,
    "cdiv": None,
    "fdiv": "div",
    "minimum": "minimum",
    "maximum": "maximum",
    "where": "where",
}

libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'


def get_tensor(shape, data_type, b_positive=False):
    x = None
    if data_type.startswith('int'):
        x = torch.randint(2**31 - 1, shape, dtype=torch_type[data_type], device='cuda')
    elif data_type.startswith('bool'):
        x = torch.randint(1, shape, dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.randn(shape, dtype=torch_type[data_type], device='cuda')

    if b_positive:
        x = torch.abs(x)

    return x


@pytest.mark.parametrize('expr, output_type, input0_type',
                         [('log', 'float32', 'float32'),
                          ('log', 'float64', 'float64'),
                          ('cos', 'float32', 'float32'),
                          ('cos', 'float64', 'float64'),
                          ('sin', 'float32', 'float32'),
                          ('sin', 'float64', 'float64'),
                          ('sqrt', 'float32', 'float32'),
                          ('sqrt', 'float64', 'float64'),
                          ('abs', 'float32', 'float32'),
                          ('exp', 'float32', 'float32'),
                          ('sigmoid', 'float32', 'float32'),
                          ])
def test_single_input(expr, output_type, input0_type):
    src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.{expr}(x)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
    fp.write(src)
    fp.flush()

    def kernel(X, Y, BLOCK: tl.constexpr):
        pass
    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
    parameters = []
    parameters.append(Parameter("X", 1))
    parameters.append(Parameter("Y", 1))
    parameters.append(Parameter("BLOCK", 1))
    kernel.__signature__ = Signature(parameters=parameters)
    kernel = triton.jit(kernel)

    shape = (128, )
    # use positive inputs for log/sqrt so the reference result is well defined
    x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
    # triton result
    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
    # reference result
    y_ref = getattr(torch, torch_ops[expr])(x)
    # compare
    assert_close(y, y_ref)


@pytest.mark.parametrize('expr, output_type, input0_type, input1_type',
                         [('umulhi', 'int32', 'int32', 'int32'),
                          ('cdiv', 'int32', 'int32', 'int32'),
                          ('fdiv', 'float32', 'float32', 'float32'),
                          ('minimum', 'float32', 'float32', 'float32'),
                          ('maximum', 'float32', 'float32', 'float32'),
                          ])
def test_two_input(expr, output_type, input0_type, input1_type):
    src = f"""
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
    x0 = tl.load(X0 + tl.arange(0, BLOCK))
    x1 = tl.load(X1 + tl.arange(0, BLOCK))
    y = tl.{expr}(x0, x1)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
    fp.write(src)
    fp.flush()

    def kernel(X0, X1, Y, BLOCK: tl.constexpr):
        pass
    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
    parameters = []
    parameters.append(Parameter("X0", 1))
    parameters.append(Parameter("X1", 1))
    parameters.append(Parameter("Y", 1))
    parameters.append(Parameter("BLOCK", 1))
    kernel.__signature__ = Signature(parameters=parameters)
    kernel = triton.jit(kernel)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x0 = get_tensor(shape, input0_type)
    x1 = get_tensor(shape, input1_type)

    # triton result
    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
    kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
    # reference result

    if expr == "cdiv":
        y_ref = (x0 + x1 - 1) // x1
    elif expr == "umulhi":
        y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
    else:
        y_ref = getattr(torch, torch_ops[expr])(x0, x1)
    # compare
    assert_close(y, y_ref)


@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type',
                         [('where', "int32", "bool", "int32", "int32"), ])
def test_three_input(expr, output_type, input0_type, input1_type, input2_type):
    src = f"""
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
    x0 = tl.load(X0 + tl.arange(0, BLOCK))
    x1 = tl.load(X1 + tl.arange(0, BLOCK))
    x2 = tl.load(X2 + tl.arange(0, BLOCK))
    y = tl.{expr}(x0, x1, x2)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
    fp.write(src)
    fp.flush()

    def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
        pass
    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
    parameters = []
    parameters.append(Parameter("X0", 1))
    parameters.append(Parameter("X1", 1))
    parameters.append(Parameter("X2", 1))
    parameters.append(Parameter("Y", 1))
    parameters.append(Parameter("BLOCK", 1))
    kernel.__signature__ = Signature(parameters=parameters)
    kernel = triton.jit(kernel)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x0 = get_tensor(shape, input0_type)
    x1 = get_tensor(shape, input1_type)
    x2 = get_tensor(shape, input2_type)

    # triton result
    y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
    kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
    # reference result

    y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
    # compare
    assert_close(y, y_ref)
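test_two_input builds its own references for umulhi (the high 32 bits of the 64-bit product) and cdiv (ceiling division). A small worked check of that arithmetic, independent of Triton, with arbitrary example values:

import torch

x0 = torch.tensor([2**30], dtype=torch.int32)
x1 = torch.tensor([6], dtype=torch.int32)

# umulhi reference: 6 * 2**30 = 0x1_8000_0000, so the high 32-bit word is 1
hi = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
assert hi.item() == 1

# cdiv reference: ceiling division via (a + b - 1) // b
assert ((x0 + x1 - 1) // x1).item() == (2**30 + 5) // 6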
178  python/tests/test_ext_elemwise.py  Normal file
@@ -0,0 +1,178 @@
import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_sin_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            x = tl.load(x_ptrs)
            y = tl.libdevice.sin(x)
            y_ptrs = y_ptr + offset
            tl.store(y_ptrs, y)

            x_ptr += iter_size
            y_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_y = torch.sin(x)
    assert_close(y, golden_y, rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_fmin_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            z = tl.libdevice.min(x, y)
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_z = torch.minimum(x, y)
    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               w_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset
            z_ptrs = z_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            z = tl.load(z_ptrs)

            w = tl.libdevice.fma_rn(x, y, z)
            w_ptrs = w_ptr + offset
            tl.store(w_ptrs, w)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size
            w_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    z = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    w = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_w = x * y + z
    assert_close(w, golden_w, rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
                          ('int32', 'libdevice.ffs', '')])
def test_libdevice(dtype_str, expr, lib_path):
    src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.{expr}(x)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    import tempfile
    from inspect import Parameter, Signature

    import _testcapi

    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
    fp.write(src)
    fp.flush()

    def kernel(X, Y, BLOCK: tl.constexpr):
        pass
    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
    parameters = []
    parameters.append(Parameter("X", 1))
    parameters.append(Parameter("Y", 1))
    parameters.append(Parameter("BLOCK", 1))
    kernel.__signature__ = Signature(parameters=parameters)
    kernel = triton.jit(kernel)

    torch_type = {
        "int32": torch.int32,
        "float32": torch.float32,
        "float64": torch.float64
    }

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = None
    if dtype_str == "int32":
        x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda")
    else:
        x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda")
    if expr == 'libdevice.ffs':
        y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda")
        for i in range(shape[0]):
            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()

    # triton result
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
    # compare
    assert_close(y, y_ref)
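test_libdevice checks tl.libdevice.ffs (find-first-set) against a pure-Python reference: v & -v isolates the lowest set bit and bit_length() turns it into a 1-based position, matching the CUDA __ffs convention of returning 0 for an all-zero input. A quick illustration of that identity:

def ffs_ref(v: int) -> int:
    # v & -v keeps only the lowest set bit; its bit_length is its 1-based index
    return (v & -v).bit_length()

assert ffs_ref(0) == 0
assert ffs_ref(1) == 1   # ...0001 -> bit 1
assert ffs_ref(12) == 3  # ...1100 -> lowest set bit is bit 3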
@@ -36,6 +36,7 @@ def str_to_ty(name):
        "bf16": triton.language.bfloat16,
        "fp32": triton.language.float32,
        "fp64": triton.language.float64,
        "i1": triton.language.int1,
        "i8": triton.language.int8,
        "i16": triton.language.int16,
        "i32": triton.language.int32,
@@ -45,7 +46,6 @@ def str_to_ty(name):
        "u32": triton.language.uint32,
        "u64": triton.language.uint64,
        "B": triton.language.int1,
        "i1": triton.language.int1,
    }
    return tys[name]
@@ -888,6 +888,13 @@ def optimize_tritongpu_ir(mod, num_stages):
    return mod


def add_external_libs(mod, libs):
    for name, path in libs.items():
        if len(name) == 0 or len(path) == 0:
            return
    _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))


def make_llvm_ir(mod):
    return _triton.translate_triton_gpu_to_llvmir(mod)

@@ -986,6 +993,8 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
    module = optimize_tritongpu_ir(module, num_stages)
    if output == "ttgir":
        return module.str()
    if extern_libs:
        add_external_libs(module, extern_libs)

    # llvm-ir
    llvm_ir = make_llvm_ir(module)
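Note the guard in add_external_libs: if any supplied name or path is empty, the whole registration is skipped. That is what lets the lib_path='' case in test_ext_elemwise.py compile, presumably by falling back to the bundled python/triton/language/libdevice.10.bc added below. A sketch of the two call shapes the tests use, reusing the kernel and tensors from the tests above:

# Explicit path: this bitcode file is linked in by add_external_libs.
kernel[(1,)](x, y, BLOCK=128,
             extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})

# Empty path: add_external_libs returns early and links nothing
# (assumed to fall back to a default/bundled libdevice).
kernel[(1,)](x, y, BLOCK=128, extern_libs={"libdevice": ""})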
BIN  python/triton/language/libdevice.10.bc  Executable file
Binary file not shown.
@@ -226,7 +226,6 @@ def fdiv(input: tl.tensor,
        raise ValueError("both operands of fdiv must have floating point scalar type")
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
    ret = builder.create_fdiv(input.handle, other.handle)
    ret.set_fdiv_ieee_rounding(ieee_rounding)
    return tl.tensor(ret, input.type)

@@ -1074,7 +1073,8 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:

def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
    x, y = binary_op_type_checking_impl(x, y, builder)
    return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
    from . import libdevice
    return libdevice.mulhi(x, y, _builder=builder)


def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
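The umulhi hunk reroutes the op through a new tl.libdevice wrapper module (imported lazily via from . import libdevice); the wrapper module's source is not part of this excerpt, but the tests above call it directly. A hedged sketch of that user-facing surface, using only names that appear in the tests:

import triton
import triton.language as tl

@triton.jit
def fma_kernel(X, Y, Z, W, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    z = tl.load(Z + offs)
    # fma_rn, min, sin, and ffs are the libdevice wrappers exercised in test_ext_elemwise.py
    tl.store(W + offs, tl.libdevice.fma_rn(x, y, z))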