[Triton-MLIR] Add ptx vprintf support (#825)

I don't know how to write a unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Authored by: ben-zhang-609
Date: 2022-11-02 16:39:09 +08:00
Committed by: GitHub
Parent: 12d60cb4a3
Commit: 5feb6e24f9
9 changed files with 386 additions and 62 deletions
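
For context, device-side printing on NVIDIA GPUs bottoms out in CUDA's vprintf(const char *format, void *valist): the compiler packs the arguments into a buffer and emits a PTX call to that symbol, which is what this commit wires up. The diff below only shows test-suite changes, so the user-facing surface is not visible here; the sketch that follows is a hypothetical usage example, and the name tl.printf and its signature are assumptions rather than something this commit confirms.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def debug_kernel(X, BLOCK: tl.constexpr):
        off = tl.arange(0, BLOCK)
        x = tl.load(X + off)
        # hypothetical device-side print; lowered to a PTX call to vprintf
        tl.printf("x = ", x)

    x = torch.randn(16, device='cuda')
    debug_kernel[(1,)](x, BLOCK=16)
    torch.cuda.synchronize()  # make sure buffered device output is flushed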


@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
     # triton result
     x_tri = to_triton(x, device=device, dst_type=dtype_x)
     z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
-    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
+    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"})
     # compare
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
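
The hardcoded /usr/local/cuda/nvvm/libdevice/libdevice.10.bc path added above only resolves on a default CUDA install. A minimal sketch of deriving it from the environment instead, reusing the names from the hunk above (treating CUDA_HOME as the conventional but assumed variable):

    import os

    # resolve libdevice relative to the CUDA install rather than hardcoding it
    cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
    libdevice = os.path.join(cuda_home, "nvvm", "libdevice", "libdevice.10.bc")

    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4,
                  extern_libs={"libdevice": libdevice})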
@@ -463,17 +463,12 @@ def test_unary_op(dtype_x, expr, device='cuda'):
 # # test math ops
 # # ----------------
 # TODO: Math module
-# # @pytest.mark.parametrize("expr", [
-# #     'exp', 'log', 'cos', 'sin'
-# # ])
-# @pytest.mark.parametrize("expr", [
-#     'exp', 'log', 'cos', 'sin'
-# ])
-# def test_math_op(expr, device='cuda'):
-#     _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
+@pytest.mark.parametrize("expr", [
+    'exp', 'log', 'cos', 'sin'
+])
+def test_math_op(expr, device='cuda'):
+    _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
 # # ----------------
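
The re-enabled test_math_op goes through the _test_unary helper; as a reading aid, a minimal standalone equivalent of one parametrization (tl.exp on float32) might look like the sketch below. This is not code from the commit, and on this branch the launch may also need the extern_libs mapping shown in the first hunk, since the math ops appear to lower to libdevice calls.

    import numpy as np
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def exp_kernel(Z, X, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        # elementwise exp over one block
        tl.store(Z + off, tl.exp(tl.load(X + off)))

    x = torch.randn(128, device='cuda', dtype=torch.float32)
    z = torch.empty_like(x)
    exp_kernel[(1,)](z, x, SIZE=128, num_warps=4)
    np.testing.assert_allclose(np.exp(x.cpu().numpy()), z.cpu().numpy(), rtol=0.01)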
@@ -1545,43 +1540,43 @@ def test_num_warps_pow2():
 # # -------------
-# @pytest.mark.parametrize("dtype_str, expr, lib_path",
-#                          [('int32', 'libdevice.ffs', ''),
-#                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
-#                           ('float64', 'libdevice.norm4d', '')])
-# def test_libdevice(dtype_str, expr, lib_path):
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('int32', 'libdevice.ffs', ''),
+                          ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
+                          ('float64', 'libdevice.norm4d', '')])
+def test_libdevice(dtype_str, expr, lib_path):
 
-#     @triton.jit
-#     def kernel(X, Y, BLOCK: tl.constexpr):
-#         x = tl.load(X + tl.arange(0, BLOCK))
-#         y = GENERATE_TEST_HERE
-#         tl.store(Y + tl.arange(0, BLOCK), y)
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)
 
-#     shape = (128, )
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    shape = (128, )
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
 
-#     if expr == 'libdevice.ffs':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
-#         y_ref = np.zeros(shape, dtype=x.dtype)
-#         for i in range(shape[0]):
-#             y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
-#     elif expr == 'libdevice.pow':
-#         # numpy does not allow negative factors in power, so we use abs()
-#         x = np.abs(x)
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
-#         y_ref = np.power(x, x)
-#     elif expr == 'libdevice.norm4d':
-#         kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
-#         y_ref = np.sqrt(4 * np.power(x, 2))
+    if expr == 'libdevice.ffs':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
+        y_ref = np.zeros(shape, dtype=x.dtype)
+        for i in range(shape[0]):
+            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
+    elif expr == 'libdevice.pow':
+        # numpy does not allow negative factors in power, so we use abs()
+        x = np.abs(x)
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+        y_ref = np.power(x, x)
+    elif expr == 'libdevice.norm4d':
+        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
+        y_ref = np.sqrt(4 * np.power(x, 2))
 
-#     x_tri = to_triton(x)
-#     # triton result
-#     y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
-#     kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
-#     # compare
-#     if expr == 'libdevice.ffs':
-#         np.testing.assert_equal(y_ref, to_numpy(y_tri))
-#     else:
-#         np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+    x_tri = to_triton(x)
+    # triton result
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    if expr == 'libdevice.ffs':
+        np.testing.assert_equal(y_ref, to_numpy(y_tri))
+    else:
+        np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
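
A note on the reference computations re-enabled above: x & -x isolates the lowest set bit of a two's-complement integer, so (x & -x).bit_length() gives the 1-indexed position of that bit, matching find-first-set semantics where ffs(0) == 0; likewise norm4d(x, x, x, x) reduces to sqrt(4 * x**2) == 2 * |x|, which is what the y_ref line computes. A quick standalone check of the ffs identity:

    # find-first-set: 1-indexed position of the lowest set bit, 0 for zero
    def ffs(v: int) -> int:
        return (v & -v).bit_length()

    assert ffs(0) == 0
    assert ffs(1) == 1   # 0b0001 -> bit 0 -> position 1
    assert ffs(12) == 3  # 0b1100 -> lowest set bit is bit 2 -> position 3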