[FRONTEND] Support alternative install locations of system libdevice.10.bc (#951)
This commit is contained in:
committed by
GitHub
parent
e419781978
commit
9490252261
18
python/tests/libdevice_testutil.py
Normal file
18
python/tests/libdevice_testutil.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Locate a system-wide CUDA ``libdevice.10.bc`` bitcode file for tests.

Exposes ``SYSTEM_LIBDEVICE_PATH`` (the detected path, or ``None`` when no
known install location exists) and ``system_libdevice_path()``, which
returns that path or raises when it could not be found.
"""
import os
from typing import Optional

# Candidate install locations, checked in list order.  Because the loop
# below does not ``break``, the LAST existing entry wins — i.e. a
# /usr/local/cuda install takes precedence over the distro-packaged one.
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
    '/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
    '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]

# Resolved once at import time so test parametrization can read it cheaply.
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
    if os.path.exists(_p):
        SYSTEM_LIBDEVICE_PATH = _p


def system_libdevice_path() -> str:
    """Return the path to the system ``libdevice.10.bc``.

    Returns:
        The first usable path detected at import time.

    Raises:
        RuntimeError: if none of the known install locations exist.
    """
    # NOTE(review): the original used ``assert`` here, which is stripped
    # under ``python -O`` and would let this ``-> str`` function return
    # ``None``; raise explicitly instead so the failure is unconditional.
    if SYSTEM_LIBDEVICE_PATH is None:
        raise RuntimeError("Could not find libdevice.10.bc path")
    return SYSTEM_LIBDEVICE_PATH
|
||||||
|
|
@@ -12,6 +12,7 @@ import triton
|
|||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||||
|
from tests.libdevice_testutil import system_libdevice_path
|
||||||
|
|
||||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||||
@@ -1552,7 +1553,7 @@ def test_num_warps_pow2():
|
|||||||
|
|
||||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||||
[('int32', 'libdevice.ffs', ''),
|
[('int32', 'libdevice.ffs', ''),
|
||||||
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
('float32', 'libdevice.pow', system_libdevice_path()),
|
||||||
('float64', 'libdevice.norm4d', '')])
|
('float64', 'libdevice.norm4d', '')])
|
||||||
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||||
|
|
||||||
|
@@ -5,6 +5,7 @@ import _testcapi
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
from tests.libdevice_testutil import system_libdevice_path
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@@ -32,8 +33,6 @@ torch_ops = {
|
|||||||
"where": "where",
|
"where": "where",
|
||||||
}
|
}
|
||||||
|
|
||||||
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor(shape, data_type, b_positive=False):
|
def get_tensor(shape, data_type, b_positive=False):
|
||||||
x = None
|
x = None
|
||||||
@@ -90,7 +89,11 @@ def kernel(X, Y, BLOCK: tl.constexpr):
|
|||||||
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
|
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
|
||||||
# triton result
|
# triton result
|
||||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||||
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
kernel[(1,)](
|
||||||
|
x, y,
|
||||||
|
BLOCK=shape[0],
|
||||||
|
extern_libs={"libdevice": system_libdevice_path()},
|
||||||
|
)
|
||||||
# reference result
|
# reference result
|
||||||
y_ref = getattr(torch, torch_ops[expr])(x)
|
y_ref = getattr(torch, torch_ops[expr])(x)
|
||||||
# compare
|
# compare
|
||||||
@@ -134,7 +137,11 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
|
|||||||
|
|
||||||
# triton result
|
# triton result
|
||||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||||
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
kernel[(1,)](
|
||||||
|
x0, x1, y,
|
||||||
|
BLOCK=shape[0],
|
||||||
|
extern_libs={"libdevice": system_libdevice_path()},
|
||||||
|
)
|
||||||
# reference result
|
# reference result
|
||||||
|
|
||||||
if expr == "cdiv":
|
if expr == "cdiv":
|
||||||
@@ -182,7 +189,11 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
|
|||||||
|
|
||||||
# triton result
|
# triton result
|
||||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||||
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
kernel[(1,)](
|
||||||
|
x0, x1, x2, y,
|
||||||
|
BLOCK=shape[0],
|
||||||
|
extern_libs={"libdevice": system_libdevice_path()},
|
||||||
|
)
|
||||||
# reference result
|
# reference result
|
||||||
|
|
||||||
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
|
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
|
||||||
|
@@ -5,6 +5,7 @@ from torch.testing import assert_close
|
|||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
from tests.libdevice_testutil import system_libdevice_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
|
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
|
||||||
@@ -125,7 +126,7 @@ def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||||
[('int32', 'libdevice.ffs', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
[('int32', 'libdevice.ffs', system_libdevice_path()),
|
||||||
('int32', 'libdevice.ffs', '')])
|
('int32', 'libdevice.ffs', '')])
|
||||||
def test_libdevice(dtype_str, expr, lib_path):
|
def test_libdevice(dtype_str, expr, lib_path):
|
||||||
src = f"""
|
src = f"""
|
||||||
|
Reference in New Issue
Block a user