From 3ca667dfa8df4b64bb47309bfbf7ffcad75bda51 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Fri, 28 Oct 2022 23:27:06 -0700
Subject: [PATCH] [Frontend] Return a scalar if all input args are scalar (#816)

---
 python/test/unit/language/test_core.py | 31 +++++++++++++++-
 python/triton/language/extern.py       | 50 ++++++++++++++------------
 2 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d7d9130d5..1282f24d9 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1546,7 +1546,7 @@ def test_num_warps_pow2():
                          [('int32', 'libdevice.ffs', ''),
                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
                           ('float64', 'libdevice.norm4d', '')])
-def test_libdevice(dtype_str, expr, lib_path):
+def test_libdevice_tensor(dtype_str, expr, lib_path):
 
     @triton.jit
     def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1582,3 +1582,32 @@ def test_libdevice(dtype_str, expr, lib_path):
         np.testing.assert_equal(y_ref, to_numpy(y_tri))
     else:
         np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+
+
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('float32', 'libdevice.pow', '')])
+def test_libdevice_scalar(dtype_str, expr, lib_path):
+
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = X
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)
+
+    shape = (128, )
+    rs = RandomState(17)
+    # generate a single random scalar input
+    x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
+    y_ref = np.zeros(shape, dtype=x.dtype)
+
+    # numpy does not allow negative factors in power, so we use abs()
+    x = np.abs(x)
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+    y_ref[:] = np.power(x, x)
+
+    # triton result
+    x_tri = to_triton(x)[0].item()
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py
index a306a2e9a..1f3c9371c 100644
--- a/python/triton/language/extern.py
+++ b/python/triton/language/extern.py
@@ -59,28 +59,34 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
     :return: the return value of the function
     '''
     dispatch_args = args.copy()
-    if len(args) == 1:
-        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
-        ret_shape = dispatch_args[0].shape
-    elif len(args) == 2:
-        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
-        dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
-        dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
-            dispatch_args[0], dispatch_args[1], _builder)
-        ret_shape = dispatch_args[0].shape
-    else:
-        for i in range(len(dispatch_args)):
-            dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
-        broadcast_arg = dispatch_args[0]
-        # Get the broadcast shape over all the arguments
-        for i in range(len(dispatch_args)):
-            _, broadcast_arg = semantic.binary_op_type_checking_impl(
-                dispatch_args[i], broadcast_arg, _builder)
-        # Change the shape of each argument based on the broadcast shape
-        for i in range(len(dispatch_args)):
-            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
-                dispatch_args[i], broadcast_arg, _builder)
-        ret_shape = broadcast_arg.shape
+    all_scalar = True
+    ret_shape = None
+    for dispatch_arg in dispatch_args:
+        if dispatch_arg.type.is_block():
+            all_scalar = False
+    if not all_scalar:
+        if len(args) == 1:
+            dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
+            ret_shape = dispatch_args[0].shape
+        elif len(args) == 2:
+            dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
+            dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
+            dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
+                dispatch_args[0], dispatch_args[1], _builder)
+            ret_shape = dispatch_args[0].shape
+        else:
+            for i in range(len(dispatch_args)):
+                dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
+            broadcast_arg = dispatch_args[0]
+            # Get the broadcast shape over all the arguments
+            for i in range(len(dispatch_args)):
+                _, broadcast_arg = semantic.binary_op_type_checking_impl(
+                    dispatch_args[i], broadcast_arg, _builder)
+            # Change the shape of each argument based on the broadcast shape
+            for i in range(len(dispatch_args)):
+                dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
+                    dispatch_args[i], broadcast_arg, _builder)
+            ret_shape = broadcast_arg.shape
     func = getattr(_builder, "create_extern_elementwise")
     return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
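
Usage note: after this change, elementwise leaves ret_shape as None when no
input argument is a block, and dispatch then produces a scalar return value,
so a libdevice call on scalars composes with other scalar expressions inside
a kernel. A minimal sketch of a kernel relying on the new behavior, assuming
the tl.libdevice.pow wrapper exercised by test_libdevice_scalar above; the
kernel name and arguments are illustrative, not taken from the patch:

    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(X, Y, scale, BLOCK: tl.constexpr):
        # scale is a scalar kernel argument; since every input to pow is a
        # scalar, the call now returns a scalar rather than a 1-element
        # tensor (the previous behavior).
        s = tl.libdevice.pow(scale, scale)
        offs = tl.arange(0, BLOCK)
        x = tl.load(X + offs)
        # a scalar multiplied by a tensor broadcasts elementwise as usual
        tl.store(Y + offs, x * s)

A launch would mirror the test above, e.g. passing a Python float for scale
and supplying extern_libs={'libdevice': lib_path} so the pow symbol can be
resolved.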