[BACKEND] Added support for 1D conversion blocked -> slice (#831)

Philippe Tillet
2022-11-01 13:19:58 -07:00
committed by GitHub
parent c9d84237e8
commit 12d60cb4a3
5 changed files with 103 additions and 78 deletions
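
The re-enabled test below exercises exactly the conversion named in the title: indexing a 1D tensor with `None` (for example `x[None, :]`) produces a value whose layout is a slice of a 2D blocked parent, so the backend must convert the 1D operand from a blocked layout into a `#triton_gpu.slice` layout. A minimal standalone sketch of that pattern (the kernel name, launch grid, and host code are illustrative, not taken from this commit; assumes a CUDA device):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def expand_1d(Z, X, SIZE: tl.constexpr):
        m = tl.arange(0, SIZE)   # 1D ranges get a blocked layout when lowered
        n = tl.arange(0, SIZE)
        x = tl.load(X + n)       # 1D load
        z = x[None, :]           # 1D -> 2D expand: needs blocked -> slice
        # the store broadcasts the (1, SIZE) value over the (SIZE, SIZE) grid
        tl.store(Z + SIZE * m[:, None] + n[None, :], z)

    x = torch.arange(32, dtype=torch.int32, device='cuda')
    z = torch.empty((32, 32), dtype=torch.int32, device='cuda')
    expand_1d[(1, )](z, x, SIZE=32, num_warps=1)
    assert (z == x[None, :].expand(32, 32)).all()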


@@ -493,61 +493,65 @@ def make_ptr_str(name, shape):
 # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`
-# @pytest.mark.parametrize("expr, dtype_str", [
-#     (f'x[{s}]', d)
-#     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
-#     for d in ['int32', 'uint32', 'uint16']
-# ])
-# def test_index1d(expr, dtype_str, device='cuda'):
-#     rank_x = expr.count(':')
-#     rank_y = expr.count(',') + 1
-#     shape_x = [32 for _ in range(rank_x)]
-#     shape_z = [32 for _ in range(rank_y)]
-#     shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
-#     shape_z_dim_mismatch = [64 for _ in range(rank_y)]
+@pytest.mark.parametrize("expr, dtype_str", [
+    (f'x[{s}]', d)
+    for s in ['None, :', ':, None',
+              # TODO: 3D
+              # 'None, :, :',
+              # ':, :, None'
+              ]
+    for d in ['int32', 'uint32', 'uint16']
+])
+def test_index1d(expr, dtype_str, device='cuda'):
+    rank_x = expr.count(':')
+    rank_y = expr.count(',') + 1
+    shape_x = [32 for _ in range(rank_x)]
+    shape_z = [32 for _ in range(rank_y)]
+    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
+    shape_z_dim_mismatch = [64 for _ in range(rank_y)]
 
-#     # Triton kernel
-#     @triton.jit
-#     def kernel(Z, X, SIZE: tl.constexpr):
-#         m = tl.arange(0, SIZE)
-#         n = tl.arange(0, SIZE)
-#         x = tl.load(X_PTR_EXPR)
-#         z = GENERATE_TEST_HERE
-#         tl.store(Z_PTR_EXPR, z)
+    # Triton kernel
+    @triton.jit
+    def kernel(Z, X, SIZE: tl.constexpr):
+        m = tl.arange(0, SIZE)
+        n = tl.arange(0, SIZE)
+        x = tl.load(X_PTR_EXPR)
+        z = GENERATE_TEST_HERE
+        tl.store(Z_PTR_EXPR, z)
 
-#     def generate_kernel(shape_x, shape_z):
-#         to_replace = {
-#             'X_PTR_EXPR': make_ptr_str('X', shape_x),
-#             'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
-#             'GENERATE_TEST_HERE': expr,
-#         }
-#         return patch_kernel(kernel, to_replace)
+    def generate_kernel(shape_x, shape_z):
+        to_replace = {
+            'X_PTR_EXPR': make_ptr_str('X', shape_x),
+            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
+            'GENERATE_TEST_HERE': expr,
+        }
+        return patch_kernel(kernel, to_replace)
 
-#     kernel_match = generate_kernel(shape_x, shape_z)
-#     kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
-#     kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
+    kernel_match = generate_kernel(shape_x, shape_z)
+    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
+    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
 
-#     # torch result
-#     x = numpy_random(shape_x, dtype_str=dtype_str)
-#     y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
-#     z_ref = eval(expr) + y
-#     # triton result
-#     z_tri = to_triton(np.empty_like(z_ref), device=device)
-#     x_tri = to_triton(x)
-#     kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#     # compare
-#     assert (z_ref == to_numpy(z_tri)).all()
+    # torch result
+    x = numpy_random(shape_x, dtype_str=dtype_str)
+    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
+    z_ref = eval(expr) + y
+    # triton result
+    z_tri = to_triton(np.empty_like(z_ref), device=device)
+    x_tri = to_triton(x)
+    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+    # compare
+    assert (z_ref == to_numpy(z_tri)).all()
 
-#     def catch_compilation_error(kernel):
-#         try:
-#             kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-#         except triton.CompilationError as e:
-#             np.testing.assert_(True)
-#         except BaseException:
-#             np.testing.assert_(False)
+    def catch_compilation_error(kernel):
+        try:
+            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+        except triton.CompilationError as e:
+            np.testing.assert_(True)
+        except BaseException:
+            np.testing.assert_(False)
 
-#     catch_compilation_error(kernel_dim_mismatch)
-#     catch_compilation_error(kernel_rank_mismatch)
+    catch_compilation_error(kernel_dim_mismatch)
+    catch_compilation_error(kernel_rank_mismatch)
 
 # ---------------
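
One note on the negative checks: `catch_compilation_error` passes if the mismatched kernel raises `triton.CompilationError` and fails on any other exception (the `as e` binding and the `np.testing.assert_(True)` branch are effectively no-ops). A tighter way to express the same intent with pytest, as a sketch rather than what the commit ships:

    import pytest
    import triton

    def check_rejected_at_compile_time(kernel, z_tri, x_tri, size):
        # a shape or rank mismatch between x and z should be caught
        # at compile time, not at run time
        with pytest.raises(triton.CompilationError):
            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=size)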