[FRONTEND] Fixed inliner and got more tests to pass (#822)

This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to use an "LLVM_SYSPATH" environment
variable to link against a local build of LLVM; this was useful for
debugging this issue.
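
For context, hooking a dialect into the MLIR inliner usually follows the pattern sketched below. This is an illustrative outline only, not a verbatim copy of what this commit adds: the struct name, the particular hooks shown, and the blanket "return true" answers are assumptions made for the sketch.

#include "mlir/Transforms/InliningUtils.h"

// Illustrative sketch: a permissive inliner interface for a dialect.
struct TritonInlinerInterface : public mlir::DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // A call to an external callable may be inlined at this call site.
  bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  // A single operation may be inlined into the destination region.
  bool isLegalToInline(mlir::Operation *op, mlir::Region *dest, bool wouldBeCloned,
                       mlir::BlockAndValueMapping &valueMapping) const final {
    return true;
  }

  // A whole source region may be inlined into the destination region.
  bool isLegalToInline(mlir::Region *dest, mlir::Region *src, bool wouldBeCloned,
                       mlir::BlockAndValueMapping &valueMapping) const final {
    return true;
  }
};

// The interface would be registered from the dialect's initialize() hook,
// e.g. addInterfaces<TritonInlinerInterface>(), so the inliner pass can fold
// Triton-level function calls like the ones exercised by test_call below.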
Philippe Tillet authored 2022-10-30 14:10:02 -07:00, committed by GitHub
parent 71428194a1
commit e61dc75942
7 changed files with 192 additions and 155 deletions


@@ -1177,20 +1177,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ---------------
-# @pytest.mark.parametrize("start", [0, 1, 7, 16])
-# def test_arange(start, device='cuda'):
-#     BLOCK = 128
-#     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
-#     @triton.jit
-#     def _kernel(z, BLOCK: tl.constexpr,
-#                 START: tl.constexpr, END: tl.constexpr):
-#         off = tl.arange(0, BLOCK)
-#         val = tl.arange(START, END)
-#         tl.store(z + off, val)
-#     _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
-#     z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
-#     triton.testing.assert_almost_equal(z_tri, z_ref)
+    @triton.jit
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
 # # ---------------
 # # test load
@@ -1248,47 +1248,47 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 #     triton.testing.allclose(out, reference_out)
-# @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
-# def test_load_cache_modifier(cache):
-#     src = torch.empty(128, device='cuda')
-#     dst = torch.empty(128, device='cuda')
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
-#     @triton.jit
-#     def _kernel(dst, src, CACHE: tl.constexpr):
-#         offsets = tl.arange(0, 128)
-#         x = tl.load(src + offsets, cache_modifier=CACHE)
-#         tl.store(dst + offsets, x)
+    @triton.jit
+    def _kernel(dst, src, CACHE: tl.constexpr):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src + offsets, cache_modifier=CACHE)
+        tl.store(dst + offsets, x)
-#     pgm = _kernel[(1,)](dst, src, CACHE=cache)
-#     ptx = pgm.asm['ptx']
-#     if cache == '':
-#         assert 'ld.global.ca' not in ptx
-#         assert 'ld.global.cg' not in ptx
-#     if cache == '.cg':
-#         assert 'ld.global.cg' in ptx
-#         assert 'ld.global.ca' not in ptx
-#     if cache == '.ca':
-#         assert 'ld.global.ca' in ptx
-#         assert 'ld.global.cg' not in ptx
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
-# @pytest.mark.parametrize("N", [16, 10, 11, 1024])
-# def test_vectorization(N):
-#     src = torch.empty(1024, device='cuda')
-#     dst = torch.empty(1024, device='cuda')
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
+def test_vectorization(N):
+    src = torch.empty(1024, device='cuda')
+    dst = torch.empty(1024, device='cuda')
-#     @triton.jit
-#     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         x = tl.load(src + offsets, mask=offsets < N)
-#         tl.store(dst + offsets, x, mask=offsets < N)
-#     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
-#     ptx = pgm.asm["ptx"]
-#     if N % 16 == 0:
-#         assert "ld.global.v4.b32" in ptx
-#     else:
-#         assert "ld.global.b32" in ptx
-#     # triton.testing.assert_almost_equal(dst, src[:N])
+    @triton.jit
+    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N)
+        tl.store(dst + offsets, x, mask=offsets < N)
+    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
+    ptx = pgm.asm["ptx"]
+    if N % 16 == 0:
+        assert "ld.global.v4.b32" in ptx
+    else:
+        assert "ld.global.b32" in ptx
+    # triton.testing.assert_almost_equal(dst, src[:N])
 # # ---------------
 # # test store
 # # ---------------
@@ -1335,145 +1335,149 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # # ----------------
-# def test_noop(device='cuda'):
-#     @triton.jit
-#     def kernel(x):
-#         pass
-#     x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
-#     kernel[(1, )](x)
+def test_noop(device='cuda'):
+    @triton.jit
+    def kernel(x):
+        pass
+    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
+    kernel[(1, )](x)
-# @pytest.mark.parametrize("value, value_type", [
-#     (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
-#     (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
-#     (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
-# ])
-# def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
-#     spec_type = None
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+    spec_type = None
-#     def cache_hook(*args, **kwargs):
-#         nonlocal spec_type
-#         spec_type = kwargs["compile"]["signature"][0]
-#     JITFunction.cache_hook = cache_hook
+    def cache_hook(*args, **kwargs):
+        nonlocal spec_type
+        spec_type = kwargs["compile"]["signature"][0]
+    JITFunction.cache_hook = cache_hook
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
-#     x = torch.tensor([3.14159], device='cuda')
-#     pgm = kernel[(1, )](value, x)
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
-#     JITFunction.cache_hook = None
-#     assert spec_type == value_type
+    JITFunction.cache_hook = None
+    assert spec_type == value_type
 # # --------------------
 # # value specialization
 # # --------------------
-# @pytest.mark.parametrize(
-#     "value, overflow",
-#     [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
-# )
-# def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
-#     @triton.jit
-#     def kernel(VALUE, X):
-#         pass
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
-#     x = torch.tensor([3.14159], device='cuda')
+    x = torch.tensor([3.14159], device='cuda')
-#     if overflow:
-#         with pytest.raises(OverflowError):
-#             kernel[(1, )](value, x)
-#     else:
-#         kernel[(1, )](value, x)
+    if overflow:
+        with pytest.raises(OverflowError):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
 # # ----------------
 # # test constexpr
 # # ----------------
-# @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
-# @pytest.mark.parametrize("is_lhs_constexpr", [False, True])
-# @pytest.mark.parametrize("is_rhs_constexpr", [True, False])
-# def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
-#     @triton.jit
-#     def kernel(Z, X, Y):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z, z)
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
-#     x_str = "3.14" if is_lhs_constexpr else "x"
-#     y_str = "4.13" if is_rhs_constexpr else "y"
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
-#     x = numpy_random((1,), dtype_str="float32")
-#     y = numpy_random((1,), dtype_str="float32")
-#     z = np.array(eval(f"{x_str} {op} {y_str}"))
-#     x_tri = to_triton(x)
-#     y_tri = to_triton(y)
-#     z_tri = to_triton(np.empty((1,), dtype=z.dtype))
-#     kernel[(1,)](z_tri, x_tri, y_tri)
-#     np.testing.assert_allclose(z, to_numpy(z_tri))
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))
-# def test_constexpr_shape():
+def test_constexpr_shape():
-#     @triton.jit
-#     def kernel(X):
-#         off = tl.arange(0, 128 + 128)
-#         tl.store(X + off, off)
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
-# def test_constexpr_scalar_shape():
+def test_constexpr_scalar_shape():
-#     @triton.jit
-#     def kernel(X, s):
-#         off = tl.arange(0, 256)
-#         val = off % (256 // s)
-#         tl.store(X + off, val)
+    @triton.jit
+    def kernel(X, s):
+        off = tl.arange(0, 256)
+        val = off % (256 // s)
+        tl.store(X + off, val)
-#     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
-#     kernel[(1,)](x_tri, 32)
-#     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri, 32)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
 # # -------------
 # # test call
 # # -------------
-# @triton.jit
-# def val_multiplier(val, i):
-#     return val * i
+@triton.jit
+def val_multiplier(val, i):
+    return val * i
-# @triton.jit
-# def vecmul_kernel(ptr, n_elements, rep):
-#     pid = tl.program_id(axis=0)
-#     offsets = pid * 128 + tl.arange(0, 128)
-#     mask = offsets < n_elements
-#     vec = tl.load(ptr + offsets, mask=mask)
-#     for i in range(1, rep):
-#         vec = val_multiplier(vec, i)
-#     tl.store(ptr + offsets, vec, mask=mask)
+@triton.jit
+def vecmul_kernel(ptr, n_elements, rep):
+    pid = tl.program_id(axis=0)
+    offsets = pid * 128 + tl.arange(0, 128)
+    mask = offsets < n_elements
+    vec = tl.load(ptr + offsets, mask=mask)
+    for i in range(1, rep):
+        vec = val_multiplier(vec, i)
+    tl.store(ptr + offsets, vec, mask=mask)
-# def test_call():
+def test_call():
-#     @triton.jit
-#     def kernel(ptr, n_elements, num1, num2):
-#         vecmul_kernel(ptr, n_elements, num1)
-#         vecmul_kernel(ptr, n_elements, num2)
+    @triton.jit
+    def kernel(ptr, n_elements, num1, num2):
+        vecmul_kernel(ptr, n_elements, num1)
+        vecmul_kernel(ptr, n_elements, num2)
-#     size = 1024
-#     rand_val = numpy_random((size,), dtype_str="float32")
-#     rand_val_tri = to_triton(rand_val, device='cuda')
-#     kernel[(size // 128,)](rand_val_tri, size, 3, 5)
+    size = 1024
+    rand_val = numpy_random((size,), dtype_str="float32")
+    rand_val_tri = to_triton(rand_val, device='cuda')
+    kernel[(size // 128,)](rand_val_tri, size, 3, 5)
-#     ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
-#     np.testing.assert_equal(to_numpy(rand_val_tri), ans)
+    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
+    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
 # # -------------
 # # test if