Merge branch 'master' into IFU_11_1_2022

This commit is contained in:
Michael Melesse
2022-11-01 17:29:10 +00:00
4 changed files with 73 additions and 24 deletions

View File

@@ -1593,9 +1593,9 @@ def test_num_warps_pow2():
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path):
def test_libdevice_tensor(dtype_str, expr, lib_path):
if torch.version.hip is not None:
pytest.skip(f"test_libdevice currently has segfaults on ROCM")
pytest.skip(f"test_libdevice_tensor currently has segfaults on ROCM")
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -1630,3 +1630,32 @@ def test_libdevice(dtype_str, expr, lib_path):
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = X
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
y_ref = np.zeros(shape, dtype=x.dtype)
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref[:] = np.power(x, x)
# triton result
x_tri = to_triton(x)[0].item()
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)