[BACKEND] Added support for 1D conversion blocked -> slice (#831)
This commit is contained in:
@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
|
||||
|
||||
|
||||
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
||||
# @pytest.mark.parametrize("expr, dtype_str", [
|
||||
# (f'x[{s}]', d)
|
||||
# for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
||||
# for d in ['int32', 'uint32', 'uint16']
|
||||
# ])
|
||||
# def test_index1d(expr, dtype_str, device='cuda'):
|
||||
# rank_x = expr.count(':')
|
||||
# rank_y = expr.count(',') + 1
|
||||
# shape_x = [32 for _ in range(rank_x)]
|
||||
# shape_z = [32 for _ in range(rank_y)]
|
||||
# shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
|
||||
# shape_z_dim_mismatch = [64 for _ in range(rank_y)]
|
||||
@pytest.mark.parametrize("expr, dtype_str", [
|
||||
(f'x[{s}]', d)
|
||||
for s in ['None, :', ':, None',
|
||||
# TODO: 3D
|
||||
# 'None, :, :',
|
||||
# ':, :, None'
|
||||
]
|
||||
for d in ['int32', 'uint32', 'uint16']
|
||||
])
|
||||
def test_index1d(expr, dtype_str, device='cuda'):
|
||||
rank_x = expr.count(':')
|
||||
rank_y = expr.count(',') + 1
|
||||
shape_x = [32 for _ in range(rank_x)]
|
||||
shape_z = [32 for _ in range(rank_y)]
|
||||
shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
|
||||
shape_z_dim_mismatch = [64 for _ in range(rank_y)]
|
||||
|
||||
# # Triton kernel
|
||||
# @triton.jit
|
||||
# def kernel(Z, X, SIZE: tl.constexpr):
|
||||
# m = tl.arange(0, SIZE)
|
||||
# n = tl.arange(0, SIZE)
|
||||
# x = tl.load(X_PTR_EXPR)
|
||||
# z = GENERATE_TEST_HERE
|
||||
# tl.store(Z_PTR_EXPR, z)
|
||||
# Triton kernel
|
||||
@triton.jit
|
||||
def kernel(Z, X, SIZE: tl.constexpr):
|
||||
m = tl.arange(0, SIZE)
|
||||
n = tl.arange(0, SIZE)
|
||||
x = tl.load(X_PTR_EXPR)
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z_PTR_EXPR, z)
|
||||
|
||||
# def generate_kernel(shape_x, shape_z):
|
||||
# to_replace = {
|
||||
# 'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
||||
# 'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
||||
# 'GENERATE_TEST_HERE': expr,
|
||||
# }
|
||||
# return patch_kernel(kernel, to_replace)
|
||||
def generate_kernel(shape_x, shape_z):
|
||||
to_replace = {
|
||||
'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
||||
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
||||
'GENERATE_TEST_HERE': expr,
|
||||
}
|
||||
return patch_kernel(kernel, to_replace)
|
||||
|
||||
# kernel_match = generate_kernel(shape_x, shape_z)
|
||||
# kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
|
||||
# kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
|
||||
kernel_match = generate_kernel(shape_x, shape_z)
|
||||
kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
|
||||
kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
|
||||
|
||||
# # torch result
|
||||
# x = numpy_random(shape_x, dtype_str=dtype_str)
|
||||
# y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
|
||||
# z_ref = eval(expr) + y
|
||||
# # triton result
|
||||
# z_tri = to_triton(np.empty_like(z_ref), device=device)
|
||||
# x_tri = to_triton(x)
|
||||
# kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
# # compare
|
||||
# assert (z_ref == to_numpy(z_tri)).all()
|
||||
# torch result
|
||||
x = numpy_random(shape_x, dtype_str=dtype_str)
|
||||
y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
|
||||
z_ref = eval(expr) + y
|
||||
# triton result
|
||||
z_tri = to_triton(np.empty_like(z_ref), device=device)
|
||||
x_tri = to_triton(x)
|
||||
kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
# compare
|
||||
assert (z_ref == to_numpy(z_tri)).all()
|
||||
|
||||
# def catch_compilation_error(kernel):
|
||||
# try:
|
||||
# kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
# except triton.CompilationError as e:
|
||||
# np.testing.assert_(True)
|
||||
# except BaseException:
|
||||
# np.testing.assert_(False)
|
||||
def catch_compilation_error(kernel):
|
||||
try:
|
||||
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
except triton.CompilationError as e:
|
||||
np.testing.assert_(True)
|
||||
except BaseException:
|
||||
np.testing.assert_(False)
|
||||
|
||||
# catch_compilation_error(kernel_dim_mismatch)
|
||||
# catch_compilation_error(kernel_rank_mismatch)
|
||||
catch_compilation_error(kernel_dim_mismatch)
|
||||
catch_compilation_error(kernel_rank_mismatch)
|
||||
|
||||
|
||||
# # ---------------
|
||||
|
Reference in New Issue
Block a user