[Frontend] Return a scalar if all input args are scalar (#816)
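Summary of the change: libdevice (extern) calls whose arguments are all scalars now produce a scalar result instead of a block. Below is a minimal sketch of what this enables, modeled directly on the test_libdevice_scalar test added in this commit; the kernel mirrors that test, and actually launching it requires a CUDA device.

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        # X is passed in as a plain Python scalar, so x is a scalar inside the kernel.
        x = X
        # After this change, a libdevice call whose inputs are all scalars returns a
        # scalar rather than a 1-element block, so it broadcasts like any other scalar.
        y = tl.libdevice.pow(x, x)
        tl.store(Y + tl.arange(0, BLOCK), y)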
@@ -1546,7 +1546,7 @@ def test_num_warps_pow2():
                          [('int32', 'libdevice.ffs', ''),
                           ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
                           ('float64', 'libdevice.norm4d', '')])
-def test_libdevice(dtype_str, expr, lib_path):
+def test_libdevice_tensor(dtype_str, expr, lib_path):
 
     @triton.jit
     def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1582,3 +1582,32 @@ def test_libdevice(dtype_str, expr, lib_path):
         np.testing.assert_equal(y_ref, to_numpy(y_tri))
     else:
         np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
+
+
+@pytest.mark.parametrize("dtype_str, expr, lib_path",
+                         [('float32', 'libdevice.pow', '')])
+def test_libdevice_scalar(dtype_str, expr, lib_path):
+
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = X
+        y = GENERATE_TEST_HERE
+        tl.store(Y + tl.arange(0, BLOCK), y)
+
+    shape = (128, )
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
+    y_ref = np.zeros(shape, dtype=x.dtype)
+
+    # numpy does not allow negative factors in power, so we use abs()
+    x = np.abs(x)
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
+    y_ref[:] = np.power(x, x)
+
+    # triton result
+    x_tri = to_triton(x)[0].item()
+    y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
+    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
+    # compare
+    np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@@ -59,6 +59,12 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
     :return: the return value of the function
     '''
     dispatch_args = args.copy()
-    if len(args) == 1:
-        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
-        ret_shape = dispatch_args[0].shape
+    all_scalar = True
+    ret_shape = None
+    for dispatch_arg in dispatch_args:
+        if dispatch_arg.type.is_block():
+            all_scalar = False
+    if not all_scalar:
+        if len(args) == 1:
+            dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
+            ret_shape = dispatch_args[0].shape
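The hunk above is the frontend change itself: elementwise first checks whether any argument is a block and only computes a block-shaped ret_shape when one is. A standalone sketch of that decision, under the assumption that a ret_shape of None makes the dispatcher emit a scalar result; the helper name infer_ret_shape is illustrative and not part of Triton.

    def infer_ret_shape(dispatch_args):
        # All-scalar inputs: leave ret_shape as None so the extern call yields a scalar.
        if all(not arg.type.is_block() for arg in dispatch_args):
            return None
        # Otherwise fall back to the block-shaped path (simplified here; the real code
        # also converts arguments to tensors and broadcasts multi-argument calls).
        return dispatch_args[0].shape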