[FRONTEND] Fixed inliner and got more tests to pass (#822)
This adds a `DialectInlinerInterface` to the Triton dialect. This, along with a few other minor semantic changes, fixes our tests on call instructions. Also added the option to provide an "LLVM_SYSPATH" environment variable to link against a local build of LLVM; this was useful for debugging this issue.
This commit is contained in:
@@ -1177,20 +1177,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# # ---------------
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("start", [0, 1, 7, 16])
|
||||
# def test_arange(start, device='cuda'):
|
||||
# BLOCK = 128
|
||||
# z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
||||
@pytest.mark.parametrize("start", [0, 1, 7, 16])
|
||||
def test_arange(start, device='cuda'):
|
||||
BLOCK = 128
|
||||
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
||||
|
||||
# @triton.jit
|
||||
# def _kernel(z, BLOCK: tl.constexpr,
|
||||
# START: tl.constexpr, END: tl.constexpr):
|
||||
# off = tl.arange(0, BLOCK)
|
||||
# val = tl.arange(START, END)
|
||||
# tl.store(z + off, val)
|
||||
# _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
||||
# z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
# triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
@triton.jit
|
||||
def _kernel(z, BLOCK: tl.constexpr,
|
||||
START: tl.constexpr, END: tl.constexpr):
|
||||
off = tl.arange(0, BLOCK)
|
||||
val = tl.arange(START, END)
|
||||
tl.store(z + off, val)
|
||||
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
|
||||
# # ---------------
|
||||
# # test load
|
||||
@@ -1248,47 +1248,47 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# triton.testing.allclose(out, reference_out)
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
# def test_load_cache_modifier(cache):
|
||||
# src = torch.empty(128, device='cuda')
|
||||
# dst = torch.empty(128, device='cuda')
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
def test_load_cache_modifier(cache):
|
||||
src = torch.empty(128, device='cuda')
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
# @triton.jit
|
||||
# def _kernel(dst, src, CACHE: tl.constexpr):
|
||||
# offsets = tl.arange(0, 128)
|
||||
# x = tl.load(src + offsets, cache_modifier=CACHE)
|
||||
# tl.store(dst + offsets, x)
|
||||
@triton.jit
|
||||
def _kernel(dst, src, CACHE: tl.constexpr):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src + offsets, cache_modifier=CACHE)
|
||||
tl.store(dst + offsets, x)
|
||||
|
||||
# pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
# ptx = pgm.asm['ptx']
|
||||
# if cache == '':
|
||||
# assert 'ld.global.ca' not in ptx
|
||||
# assert 'ld.global.cg' not in ptx
|
||||
# if cache == '.cg':
|
||||
# assert 'ld.global.cg' in ptx
|
||||
# assert 'ld.global.ca' not in ptx
|
||||
# if cache == '.ca':
|
||||
# assert 'ld.global.ca' in ptx
|
||||
# assert 'ld.global.cg' not in ptx
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
||||
# def test_vectorization(N):
|
||||
# src = torch.empty(1024, device='cuda')
|
||||
# dst = torch.empty(1024, device='cuda')
|
||||
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
||||
def test_vectorization(N):
|
||||
src = torch.empty(1024, device='cuda')
|
||||
dst = torch.empty(1024, device='cuda')
|
||||
|
||||
# @triton.jit
|
||||
# def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
|
||||
# offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
# x = tl.load(src + offsets, mask=offsets < N)
|
||||
# tl.store(dst + offsets, x, mask=offsets < N)
|
||||
# pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||
# ptx = pgm.asm["ptx"]
|
||||
# if N % 16 == 0:
|
||||
# assert "ld.global.v4.b32" in ptx
|
||||
# else:
|
||||
# assert "ld.global.b32" in ptx
|
||||
# # triton.testing.assert_almost_equal(dst, src[:N])
|
||||
@triton.jit
|
||||
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(src + offsets, mask=offsets < N)
|
||||
tl.store(dst + offsets, x, mask=offsets < N)
|
||||
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||
ptx = pgm.asm["ptx"]
|
||||
if N % 16 == 0:
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
# # ---------------
|
||||
# # test store
|
||||
# # ---------------
|
||||
@@ -1335,145 +1335,149 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# # ----------------
|
||||
|
||||
|
||||
# def test_noop(device='cuda'):
|
||||
# @triton.jit
|
||||
# def kernel(x):
|
||||
# pass
|
||||
# x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
|
||||
# kernel[(1, )](x)
|
||||
def test_noop(device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(x):
|
||||
pass
|
||||
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
|
||||
kernel[(1, )](x)
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("value, value_type", [
|
||||
# (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
# (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
# (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
# ])
|
||||
# def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
# spec_type = None
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
spec_type = None
|
||||
|
||||
# def cache_hook(*args, **kwargs):
|
||||
# nonlocal spec_type
|
||||
# spec_type = kwargs["compile"]["signature"][0]
|
||||
# JITFunction.cache_hook = cache_hook
|
||||
def cache_hook(*args, **kwargs):
|
||||
nonlocal spec_type
|
||||
spec_type = kwargs["compile"]["signature"][0]
|
||||
JITFunction.cache_hook = cache_hook
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(VALUE, X):
|
||||
# pass
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
# x = torch.tensor([3.14159], device='cuda')
|
||||
# pgm = kernel[(1, )](value, x)
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
pgm = kernel[(1, )](value, x)
|
||||
|
||||
# JITFunction.cache_hook = None
|
||||
# assert spec_type == value_type
|
||||
JITFunction.cache_hook = None
|
||||
assert spec_type == value_type
|
||||
|
||||
# # --------------------
|
||||
# # value specialization
|
||||
# # --------------------
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "value, overflow",
|
||||
# [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
|
||||
# )
|
||||
# def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"value, overflow",
|
||||
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
|
||||
)
|
||||
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(VALUE, X):
|
||||
# pass
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
# x = torch.tensor([3.14159], device='cuda')
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
|
||||
# if overflow:
|
||||
# with pytest.raises(OverflowError):
|
||||
# kernel[(1, )](value, x)
|
||||
# else:
|
||||
# kernel[(1, )](value, x)
|
||||
if overflow:
|
||||
with pytest.raises(OverflowError):
|
||||
kernel[(1, )](value, x)
|
||||
else:
|
||||
kernel[(1, )](value, x)
|
||||
|
||||
|
||||
# # ----------------
|
||||
# # test constexpr
|
||||
# # ----------------
|
||||
|
||||
# @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
|
||||
# @pytest.mark.parametrize("is_lhs_constexpr", [False, True])
|
||||
# @pytest.mark.parametrize("is_rhs_constexpr", [True, False])
|
||||
# def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
|
||||
@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
|
||||
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
|
||||
@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
|
||||
def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(Z, X, Y):
|
||||
# x = tl.load(X)
|
||||
# y = tl.load(Y)
|
||||
# z = GENERATE_TEST_HERE
|
||||
# tl.store(Z, z)
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y):
|
||||
x = tl.load(X)
|
||||
y = tl.load(Y)
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z, z)
|
||||
|
||||
# x_str = "3.14" if is_lhs_constexpr else "x"
|
||||
# y_str = "4.13" if is_rhs_constexpr else "y"
|
||||
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
|
||||
# x = numpy_random((1,), dtype_str="float32")
|
||||
# y = numpy_random((1,), dtype_str="float32")
|
||||
# z = np.array(eval(f"{x_str} {op} {y_str}"))
|
||||
# x_tri = to_triton(x)
|
||||
# y_tri = to_triton(y)
|
||||
# z_tri = to_triton(np.empty((1,), dtype=z.dtype))
|
||||
# kernel[(1,)](z_tri, x_tri, y_tri)
|
||||
# np.testing.assert_allclose(z, to_numpy(z_tri))
|
||||
x_str = "3.14" if is_lhs_constexpr else "x"
|
||||
y_str = "4.13" if is_rhs_constexpr else "y"
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
|
||||
x = numpy_random((1,), dtype_str="float32")
|
||||
y = numpy_random((1,), dtype_str="float32")
|
||||
z = np.array(eval(f"{x_str} {op} {y_str}"))
|
||||
x_tri = to_triton(x)
|
||||
y_tri = to_triton(y)
|
||||
z_tri = to_triton(np.empty((1,), dtype=z.dtype))
|
||||
kernel[(1,)](z_tri, x_tri, y_tri)
|
||||
np.testing.assert_allclose(z, to_numpy(z_tri))
|
||||
|
||||
|
||||
# def test_constexpr_shape():
|
||||
def test_constexpr_shape():
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(X):
|
||||
# off = tl.arange(0, 128 + 128)
|
||||
# tl.store(X + off, off)
|
||||
@triton.jit
|
||||
def kernel(X):
|
||||
off = tl.arange(0, 128 + 128)
|
||||
tl.store(X + off, off)
|
||||
|
||||
# x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||
# kernel[(1,)](x_tri)
|
||||
# np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
|
||||
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||
kernel[(1,)](x_tri)
|
||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
|
||||
|
||||
|
||||
# def test_constexpr_scalar_shape():
|
||||
def test_constexpr_scalar_shape():
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(X, s):
|
||||
# off = tl.arange(0, 256)
|
||||
# val = off % (256 // s)
|
||||
# tl.store(X + off, val)
|
||||
@triton.jit
|
||||
def kernel(X, s):
|
||||
off = tl.arange(0, 256)
|
||||
val = off % (256 // s)
|
||||
tl.store(X + off, val)
|
||||
|
||||
# x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||
# kernel[(1,)](x_tri, 32)
|
||||
# np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
||||
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||
kernel[(1,)](x_tri, 32)
|
||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
||||
|
||||
# # -------------
|
||||
# # test call
|
||||
# # -------------
|
||||
|
||||
|
||||
# @triton.jit
|
||||
# def val_multiplier(val, i):
|
||||
# return val * i
|
||||
@triton.jit
|
||||
def val_multiplier(val, i):
|
||||
return val * i
|
||||
|
||||
|
||||
# @triton.jit
|
||||
# def vecmul_kernel(ptr, n_elements, rep):
|
||||
# pid = tl.program_id(axis=0)
|
||||
# offsets = pid * 128 + tl.arange(0, 128)
|
||||
# mask = offsets < n_elements
|
||||
# vec = tl.load(ptr + offsets, mask=mask)
|
||||
# for i in range(1, rep):
|
||||
# vec = val_multiplier(vec, i)
|
||||
# tl.store(ptr + offsets, vec, mask=mask)
|
||||
@triton.jit
|
||||
def vecmul_kernel(ptr, n_elements, rep):
|
||||
pid = tl.program_id(axis=0)
|
||||
offsets = pid * 128 + tl.arange(0, 128)
|
||||
mask = offsets < n_elements
|
||||
vec = tl.load(ptr + offsets, mask=mask)
|
||||
for i in range(1, rep):
|
||||
vec = val_multiplier(vec, i)
|
||||
tl.store(ptr + offsets, vec, mask=mask)
|
||||
|
||||
|
||||
# def test_call():
|
||||
def test_call():
|
||||
|
||||
# @triton.jit
|
||||
# def kernel(ptr, n_elements, num1, num2):
|
||||
# vecmul_kernel(ptr, n_elements, num1)
|
||||
# vecmul_kernel(ptr, n_elements, num2)
|
||||
@triton.jit
|
||||
def kernel(ptr, n_elements, num1, num2):
|
||||
vecmul_kernel(ptr, n_elements, num1)
|
||||
vecmul_kernel(ptr, n_elements, num2)
|
||||
|
||||
# size = 1024
|
||||
# rand_val = numpy_random((size,), dtype_str="float32")
|
||||
# rand_val_tri = to_triton(rand_val, device='cuda')
|
||||
# kernel[(size // 128,)](rand_val_tri, size, 3, 5)
|
||||
size = 1024
|
||||
rand_val = numpy_random((size,), dtype_str="float32")
|
||||
rand_val_tri = to_triton(rand_val, device='cuda')
|
||||
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
|
||||
|
||||
# ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
# np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
|
||||
# # -------------
|
||||
# # test if
|
||||
|
Reference in New Issue
Block a user