[FRONTEND] Raise broadcast error (#555)

This commit is contained in:
Keren Zhou
2022-06-30 17:32:07 -07:00
committed by GitHub
parent f733327ba4
commit a74cce375f
2 changed files with 30 additions and 7 deletions

View File

@@ -385,6 +385,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
rank_y = expr.count(',') + 1
shape_x = [32 for _ in range(rank_x)]
shape_z = [32 for _ in range(rank_y)]
shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
shape_z_dim_mismatch = [64 for _ in range(rank_y)]
# Triton kernel
@triton.jit
@@ -395,12 +397,17 @@ def test_index1d(expr, dtype_str, device='cuda'):
z = GENERATE_TEST_HERE
tl.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
kernel = patch_kernel(kernel, to_replace)
def generate_kernel(shape_x, shape_z):
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
'GENERATE_TEST_HERE': expr,
}
return patch_kernel(kernel, to_replace)
kernel_match = generate_kernel(shape_x, shape_z)
kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
# torch result
x = numpy_random(shape_x, dtype_str=dtype_str)
@@ -409,10 +416,21 @@ def test_index1d(expr, dtype_str, device='cuda'):
# triton result
z_tri = to_triton(np.empty_like(z_ref), device=device)
x_tri = to_triton(x)
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
# compare
assert (z_ref == to_numpy(z_tri)).all()
def catch_compilation_error(kernel):
try:
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
except triton.code_gen.CompilationError as e:
np.testing.assert_(True)
except BaseException:
np.testing.assert_(False)
catch_compilation_error(kernel_dim_mismatch)
catch_compilation_error(kernel_rank_mismatch)
# ---------------
# test tuples