[FRONTEND] Improved constexpr handling (#493)
@@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
-# -------------------------
-# test dynamic parallelism
-# -------------------------
-
-
-@triton.jit
-def mult(x, alpha):
-    tl.store(x + tl.program_id(0), alpha)
+# ----------------
+# test constexpr
+# ----------------
+
+@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
+@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
+@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
+def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
+
+    @triton.jit
+    def kernel(Z, X, Y):
+        x = tl.load(X)
+        y = tl.load(Y)
+        z = GENERATE_TEST_HERE
+        tl.store(Z, z)
+
+    x_str = "3.14" if is_lhs_constexpr else "x"
+    y_str = "4.13" if is_rhs_constexpr else "y"
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
+    x = numpy_random((1,), dtype_str="float32")
+    y = numpy_random((1,), dtype_str="float32")
+    z = np.array(eval(f"{x_str} {op} {y_str}"))
+    x_tri = to_triton(x)
+    y_tri = to_triton(y)
+    z_tri = to_triton(np.empty((1,), dtype=z.dtype))
+    kernel[(1,)](z_tri, x_tri, y_tri)
+    np.testing.assert_allclose(z, to_numpy(z_tri))


-@triton.jit
-def stub(X, alpha, grid_0, grid_1, grid_2):
-    tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
+def test_constexpr_shape():
+
+    @triton.jit
+    def kernel(X):
+        off = tl.arange(0, 128 + 128)
+        tl.store(X + off, off)
+
-# def test_dyn_par(cond=True, device='cuda'):
-#     n_pids = 10
-#     # pids = torch.arange(n_pids, device=device)
-#     # alpha = 2.0
-#     # x_ref = pids * alpha
-#     x_tri = torch.full((10,), fill_value=-1., device=device)
-#     # cond = torch.tensor([cond], device=device)
-#     stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
-#     print(x_tri)
-#     # triton.testing.assert_almost_equal(x_ref, x_tri)
+    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
+    kernel[(1,)](x_tri)
+    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
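Note: `patch_kernel` is a test helper defined earlier in `test_core.py` and not shown in this hunk. A minimal sketch of such a helper, assuming it works by textual substitution on the JIT function's source string (the body below is an assumption for illustration, not taken from this commit):

    import copy


    def patch_kernel(template, to_replace):
        # Copy the @triton.jit template and replace each placeholder
        # (e.g. GENERATE_TEST_HERE) in its source string; the patched
        # source is what gets compiled when the kernel is launched.
        kernel = copy.deepcopy(template)
        for key, value in to_replace.items():
            kernel.src = kernel.src.replace(key, value)
        return kernel

This is what lets one parametrized test stamp out all 28 operator/constexpr combinations (7 operators x 2 x 2 operand kinds) from a single kernel template.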
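For background on what the new tests exercise: the Triton frontend folds compile-time constants into the kernel during code generation, so a literal like `3.14` patched into the source becomes part of the generated code, and shape arithmetic like `128 + 128` in `tl.arange` is resolved before lowering (`tl.arange` requires constexpr bounds). A minimal sketch of the same idea using an explicitly annotated constexpr parameter (illustrative only; the kernel name and the torch-based harness are not from this commit):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def iota(X, BLOCK: tl.constexpr):
        # BLOCK is a compile-time constant, so it is a valid
        # constexpr bound for tl.arange.
        off = tl.arange(0, BLOCK)
        tl.store(X + off, off)


    x = torch.empty(256, dtype=torch.int32, device='cuda')
    iota[(1,)](x, BLOCK=256)
    assert torch.equal(x, torch.arange(256, dtype=torch.int32, device='cuda'))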