[FRONTEND] Improved constexpr handling (#493)

This commit is contained in:
Philippe Tillet
2022-04-12 00:02:54 -07:00
committed by GitHub
parent 14b0fd4cfb
commit 76bfac9f15
4 changed files with 70 additions and 96 deletions

View File

@@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)
# -------------------------
# test dynamic parallelism
# -------------------------
@triton.jit
def mult(x, alpha):
tl.store(x + tl.program_id(0), alpha)
# ----------------
# test constexpr
# ----------------
@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
@triton.jit
def kernel(Z, X, Y):
x = tl.load(X)
y = tl.load(Y)
z = GENERATE_TEST_HERE
tl.store(Z, z)
x_str = "3.14" if is_lhs_constexpr else "x"
y_str = "4.13" if is_rhs_constexpr else "y"
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
x = numpy_random((1,), dtype_str="float32")
y = numpy_random((1,), dtype_str="float32")
z = np.array(eval(f"{x_str} {op} {y_str}"))
x_tri = to_triton(x)
y_tri = to_triton(y)
z_tri = to_triton(np.empty((1,), dtype=z.dtype))
kernel[(1,)](z_tri, x_tri, y_tri)
np.testing.assert_allclose(z, to_numpy(z_tri))
@triton.jit
def stub(X, alpha, grid_0, grid_1, grid_2):
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
def test_constexpr_shape():
@triton.jit
def kernel(X):
off = tl.arange(0, 128 + 128)
tl.store(X + off, off)
# def test_dyn_par(cond=True, device='cuda'):
# n_pids = 10
# # pids = torch.arange(n_pids, device=device)
# # alpha = 2.0
# # x_ref = pids * alpha
# x_tri = torch.full((10,), fill_value=-1., device=device)
# # cond = torch.tensor([cond], device=device)
# stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
# print(x_tri)
# # triton.testing.assert_almost_equal(x_ref, x_tri)
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
kernel[(1,)](x_tri)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))