[Triton-MLIR] Add ptx vprintf support (#825)
Not sure how to write a unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
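For context, the user-facing surface of this change is tl.printf inside a @triton.jit kernel, lowered to the PTX-level vprintf call. A minimal sketch (kernel and variable names are illustrative), condensed from the printf_helper.py test added below:

import torch

import triton
import triton.language as tl


@triton.jit
def print_kernel(X, BLOCK: tl.constexpr):
    # load a block of values and print them from the device;
    # the first argument of tl.printf is a prefix string (empty here)
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.printf("", x)


x = torch.arange(0, 128, dtype=torch.int32, device="cuda")
print_kernel[(1,)](x, BLOCK=128)

On a CUDA device this should write the 128 loaded values to stdout, which is what the new test_printf.py checks.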
python/tests/printf_helper.py (new file, 56 lines)
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    "int32": torch.int32,
    'int64': torch.long,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}


def get_tensor(shape, data_type, b_positive=False):
    x = None
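    # note: the integer and floating-point branches below currently build the same arange tensor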
    if data_type.startswith('int'):
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')

    return x

# @pytest.mark.parametrize('data_type',
#                          [("int8"),
#                           ('int16'),
#                           ('int32'),
#                           ("int64"),
#                           ('float16'),
#                           ("float32"),
#                           ("float64")])


def printf(data_type):
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
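        # print each element from the device via the new printf lowering;
        # test_printf.py below parses the numeric tokens back from stdout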
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = get_tensor(shape, data_type)
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


printf("float16")
printf("int8")

@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     # triton result
     x_tri = to_triton(x, device=device, dst_type=dtype_x)
     z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
-    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
+    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
     # compare
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
 # # test math ops
 # # ----------------

-# TODO: Math module
-# # @pytest.mark.parametrize("expr", [
-# #     'exp', 'log', 'cos', 'sin'
-# # ])
-

-# @pytest.mark.parametrize("expr", [
-#     'exp', 'log', 'cos', 'sin'
-# ])
-# def test_math_op(expr, device='cuda'):
-#     _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
+@pytest.mark.parametrize("expr", [
+    'exp', 'log', 'cos', 'sin'
+])
+def test_math_op(expr, device='cuda'):
+    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)


 # # ----------------

@@ -1545,43 +1540,43 @@ def test_num_warps_pow2():
 # # -------------


-# @pytest.mark.parametrize("dtype_str, expr, lib_path",
-#                          [('int32', 'libdevice.ffs', ''),
-#                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
-#                           ('float64', 'libdevice.norm4d', '')])
-# def test_libdevice(dtype_str, expr, lib_path):
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('int32', 'libdevice.ffs', ''),
+                          ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
+                          ('float64', 'libdevice.norm4d', '')])
+def test_libdevice(dtype_str, expr, lib_path):

-#     @triton.jit
-#     def kernel(X, Y, BLOCK: tl.constexpr):
-#         x = tl.load(X + tl.arange(0, BLOCK))
-#         y = GENERATE_TEST_HERE
-#         tl.store(Y + tl.arange(0, BLOCK), y)
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)

-#     shape = (128, )
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    shape = (128, )
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)

-#     if expr == 'libdevice.ffs':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
-#         y_ref = np.zeros(shape, dtype=x.dtype)
-#         for i in range(shape[0]):
-#             y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
-#     elif expr == 'libdevice.pow':
-#         # numpy does not allow negative factors in power, so we use abs()
-#         x = np.abs(x)
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
-#         y_ref = np.power(x, x)
-#     elif expr == 'libdevice.norm4d':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
-#         y_ref = np.sqrt(4 * np.power(x, 2))
+    if expr == 'libdevice.ffs':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
+        y_ref = np.zeros(shape, dtype=x.dtype)
+        for i in range(shape[0]):
+            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
+    elif expr == 'libdevice.pow':
+        # numpy does not allow negative factors in power, so we use abs()
+        x = np.abs(x)
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+        y_ref = np.power(x, x)
+    elif expr == 'libdevice.norm4d':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
+        y_ref = np.sqrt(4 * np.power(x, 2))

-#     x_tri = to_triton(x)
-#     # triton result
-#     y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
-#     kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
-#     # compare
-#     if expr == 'libdevice.ffs':
-#         np.testing.assert_equal(y_ref, to_numpy(y_tri))
-#     else:
-#         np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+    x_tri = to_triton(x)
+    # triton result
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    if expr == 'libdevice.ffs':
+        np.testing.assert_equal(y_ref, to_numpy(y_tri))
+    else:
+        np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

python/tests/test_printf.py (new file, 21 lines)
import os
import subprocess

dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")


def test_printf():
    proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
    (outs, err) = proc.communicate()
    outs = outs.split()
    new_lines = set()
    for line in outs:
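        # non-numeric tokens simply raise here and are skipped (the exception is printed)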
        try:
            value = int(float(line))
            new_lines.add(value)
        except Exception as e:
            print(e)
    for i in range(128):
        assert i in new_lines
    assert len(new_lines) == 128