[Frontend] Return a scalar if all input args are scalar (#816)

This commit is contained in:
Keren Zhou
2022-10-28 23:27:06 -07:00
committed by GitHub
parent 5ca1ed0101
commit 3ca667dfa8
2 changed files with 58 additions and 23 deletions

View File

@@ -1546,7 +1546,7 @@ def test_num_warps_pow2():
[('int32', 'libdevice.ffs', ''), [('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')]) ('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path): def test_libdevice_tensor(dtype_str, expr, lib_path):
@triton.jit @triton.jit
def kernel(X, Y, BLOCK: tl.constexpr): def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1582,3 +1582,32 @@ def test_libdevice(dtype_str, expr, lib_path):
np.testing.assert_equal(y_ref, to_numpy(y_tri)) np.testing.assert_equal(y_ref, to_numpy(y_tri))
else: else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = X
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
y_ref = np.zeros(shape, dtype=x.dtype)
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref[:] = np.power(x, x)
# triton result
x_tri = to_triton(x)[0].item()
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

View File

@@ -59,28 +59,34 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
:return: the return value of the function :return: the return value of the function
''' '''
dispatch_args = args.copy() dispatch_args = args.copy()
if len(args) == 1: all_scalar = True
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) ret_shape = None
ret_shape = dispatch_args[0].shape for dispatch_arg in dispatch_args:
elif len(args) == 2: if dispatch_arg.type.is_block():
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) all_scalar = False
dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) if not all_scalar:
dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( if len(args) == 1:
dispatch_args[0], dispatch_args[1], _builder) dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
ret_shape = dispatch_args[0].shape ret_shape = dispatch_args[0].shape
else: elif len(args) == 2:
for i in range(len(dispatch_args)): dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
broadcast_arg = dispatch_args[0] dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
# Get the broadcast shape over all the arguments dispatch_args[0], dispatch_args[1], _builder)
for i in range(len(dispatch_args)): ret_shape = dispatch_args[0].shape
_, broadcast_arg = semantic.binary_op_type_checking_impl( else:
dispatch_args[i], broadcast_arg, _builder) for i in range(len(dispatch_args)):
# Change the shape of each argument based on the broadcast shape dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
for i in range(len(dispatch_args)): broadcast_arg = dispatch_args[0]
dispatch_args[i], _ = semantic.binary_op_type_checking_impl( # Get the broadcast shape over all the arguments
dispatch_args[i], broadcast_arg, _builder) for i in range(len(dispatch_args)):
ret_shape = broadcast_arg.shape _, broadcast_arg = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise") func = getattr(_builder, "create_extern_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)