From 9490252261aae4e9fa1ab2ff8aede48769057894 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 5 Dec 2022 19:41:44 -0800 Subject: [PATCH] [FRONTEND] Support alternative install locations of system libdevice.10.bc (#951) --- python/tests/libdevice_testutil.py | 18 ++++++++++++++++++ python/tests/test_core.py | 3 ++- python/tests/test_elementwise.py | 21 ++++++++++++++++----- python/tests/test_ext_elemwise.py | 3 ++- 4 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 python/tests/libdevice_testutil.py diff --git a/python/tests/libdevice_testutil.py b/python/tests/libdevice_testutil.py new file mode 100644 index 000000000..03bc08660 --- /dev/null +++ b/python/tests/libdevice_testutil.py @@ -0,0 +1,18 @@ +import os +from typing import Optional + +_SYSTEM_LIBDEVICE_SEARCH_PATHS = [ + '/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc', + '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc', +] + +SYSTEM_LIBDEVICE_PATH: Optional[str] = None +for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS: + if os.path.exists(_p): + SYSTEM_LIBDEVICE_PATH = _p + +def system_libdevice_path() -> str: + assert SYSTEM_LIBDEVICE_PATH is not None, \ + "Could not find libdevice.10.bc path" + return SYSTEM_LIBDEVICE_PATH + diff --git a/python/tests/test_core.py b/python/tests/test_core.py index cf26d44e0..a2175469c 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -12,6 +12,7 @@ import triton import triton._C.libtriton.triton as _triton import triton.language as tl from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret +from tests.libdevice_testutil import system_libdevice_path int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -1552,7 +1553,7 @@ def test_num_warps_pow2(): @pytest.mark.parametrize("dtype_str, expr, lib_path", [('int32', 'libdevice.ffs', ''), - ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float32', 'libdevice.pow', system_libdevice_path()), ('float64', 'libdevice.norm4d', '')]) def test_libdevice_tensor(dtype_str, expr, lib_path): diff --git a/python/tests/test_elementwise.py b/python/tests/test_elementwise.py index 2ff4acae9..f88702296 100644 --- a/python/tests/test_elementwise.py +++ b/python/tests/test_elementwise.py @@ -5,6 +5,7 @@ import _testcapi import pytest import torch from torch.testing import assert_close +from tests.libdevice_testutil import system_libdevice_path import triton import triton.language as tl @@ -32,8 +33,6 @@ torch_ops = { "where": "where", } -libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc' - def get_tensor(shape, data_type, b_positive=False): x = None @@ -90,7 +89,11 @@ def kernel(X, Y, BLOCK: tl.constexpr): x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt') # triton result y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") - kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + kernel[(1,)]( + x, y, + BLOCK=shape[0], + extern_libs={"libdevice": system_libdevice_path()}, + ) # reference result y_ref = getattr(torch, torch_ops[expr])(x) # compare @@ -134,7 +137,11 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr): # triton result y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") - kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + kernel[(1,)]( + x0, x1, y, + BLOCK=shape[0], + extern_libs={"libdevice": system_libdevice_path()}, + ) # reference result if expr == "cdiv": @@ -182,7 +189,11 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr): # triton result y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda") - kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice}) + kernel[(1,)]( + x0, x1, x2, y, + BLOCK=shape[0], + extern_libs={"libdevice": system_libdevice_path()}, + ) # reference result y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2) diff --git a/python/tests/test_ext_elemwise.py b/python/tests/test_ext_elemwise.py index 9e44db65e..cef11978b 100644 --- a/python/tests/test_ext_elemwise.py +++ b/python/tests/test_ext_elemwise.py @@ -5,6 +5,7 @@ from torch.testing import assert_close import triton import triton.language as tl +from tests.libdevice_testutil import system_libdevice_path @pytest.mark.parametrize('num_warps, block_size, iter_size', [ @@ -125,7 +126,7 @@ def test_fmad_rn_no_mask(num_warps, block_size, iter_size): @pytest.mark.parametrize("dtype_str, expr, lib_path", - [('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + [('int32', 'libdevice.ffs', system_libdevice_path()), ('int32', 'libdevice.ffs', '')]) def test_libdevice(dtype_str, expr, lib_path): src = f"""