[FRONTEND] Support alternative install locations of system libdevice.10.bc (#951)
This commit is contained in:
committed by
GitHub
parent
e419781978
commit
9490252261
@@ -5,6 +5,7 @@ import _testcapi
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
from tests.libdevice_testutil import system_libdevice_path
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -32,8 +33,6 @@ torch_ops = {
|
||||
"where": "where",
|
||||
}
|
||||
|
||||
libdevice = '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'
|
||||
|
||||
|
||||
def get_tensor(shape, data_type, b_positive=False):
|
||||
x = None
|
||||
@@ -90,7 +89,11 @@ def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
y_ref = getattr(torch, torch_ops[expr])(x)
|
||||
# compare
|
||||
@@ -134,7 +137,11 @@ def kernel(X0, X1, Y, BLOCK: tl.constexpr):
|
||||
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x0, x1, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x0, x1, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
|
||||
if expr == "cdiv":
|
||||
@@ -182,7 +189,11 @@ def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
|
||||
|
||||
# triton result
|
||||
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
|
||||
kernel[(1,)](x0, x1, x2, y, BLOCK=shape[0], extern_libs={"libdevice": libdevice})
|
||||
kernel[(1,)](
|
||||
x0, x1, x2, y,
|
||||
BLOCK=shape[0],
|
||||
extern_libs={"libdevice": system_libdevice_path()},
|
||||
)
|
||||
# reference result
|
||||
|
||||
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
|
||||
|
Reference in New Issue
Block a user