[FRONTEND] Raise broadcast error (#555)

commit a74cce375f
parent f733327ba4
Author: Keren Zhou
Date: 2022-06-30 17:32:07 -07:00
Committed by: GitHub
2 changed files with 30 additions and 7 deletions


@@ -385,6 +385,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
     rank_y = expr.count(',') + 1
     shape_x = [32 for _ in range(rank_x)]
     shape_z = [32 for _ in range(rank_y)]
+    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
+    shape_z_dim_mismatch = [64 for _ in range(rank_y)]
     # Triton kernel
     @triton.jit
@@ -395,12 +397,17 @@ def test_index1d(expr, dtype_str, device='cuda'):
         z = GENERATE_TEST_HERE
         tl.store(Z_PTR_EXPR, z)
-    to_replace = {
-        'X_PTR_EXPR': make_ptr_str('X', shape_x),
-        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
-        'GENERATE_TEST_HERE': expr,
-    }
-    kernel = patch_kernel(kernel, to_replace)
+    def generate_kernel(shape_x, shape_z):
+        to_replace = {
+            'X_PTR_EXPR': make_ptr_str('X', shape_x),
+            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
+            'GENERATE_TEST_HERE': expr,
+        }
+        return patch_kernel(kernel, to_replace)
+    kernel_match = generate_kernel(shape_x, shape_z)
+    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
+    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
     # torch result
     x = numpy_random(shape_x, dtype_str=dtype_str)
@@ -409,10 +416,21 @@ def test_index1d(expr, dtype_str, device='cuda'):
     # triton result
     z_tri = to_triton(np.empty_like(z_ref), device=device)
     x_tri = to_triton(x)
-    kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
     # compare
     assert (z_ref == to_numpy(z_tri)).all()
+
+    def catch_compilation_error(kernel):
+        try:
+            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+        except triton.code_gen.CompilationError:
+            np.testing.assert_(True)
+        except BaseException:
+            np.testing.assert_(False)
+
+    catch_compilation_error(kernel_dim_mismatch)
+    catch_compilation_error(kernel_rank_mismatch)
 # ---------------
 # test tuples
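
Note on the test pattern above: catch_compilation_error passes when the launch raises triton.code_gen.CompilationError (compilation happens on first launch) and fails on any other exception; the np.testing.assert_(True) in the expected branch is effectively a no-op. An equivalent, more conventional phrasing uses pytest.raises. The sketch below is illustrative only and not part of the commit: it assumes pytest is importable in the test environment, reuses the kernels and tensors defined above, and assert_compilation_error is a hypothetical name.

    import pytest
    import triton

    def assert_compilation_error(kernel, z_tri, x_tri, size):
        # Hypothetical helper, not in the commit: the launch must raise
        # CompilationError; any other exception (or none) fails the test.
        with pytest.raises(triton.code_gen.CompilationError):
            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=size)

    assert_compilation_error(kernel_dim_mismatch, z_tri, x_tri, shape_x[0])
    assert_compilation_error(kernel_rank_mismatch, z_tri, x_tri, shape_x[0])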


@@ -482,6 +482,11 @@ def broadcast_impl_shape(input: tl.tensor,
         raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
     if shape == src_shape:
         return input
+    for i in range(len(src_shape)):
+        if shape[i] != src_shape[i] and src_shape[i] != 1:
+            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                             f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
+                             f" {i}: {src_shape}, {shape}")
     ret_ty = tl.block_type(input.type.scalar, shape)
     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
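
The check added here enforces the standard NumPy-style broadcasting rule: once ranks match, a source dimension may only be stretched if it equals 1. A self-contained sketch of the same logic in plain Python (check_broadcastable is a hypothetical name; the real check lives inside broadcast_impl_shape and operates on Triton block types):

    def check_broadcastable(src_shape, shape):
        # Mirror of the validation above: ranks must match, and every
        # source dim must equal the target dim or be a singleton (1).
        if len(shape) != len(src_shape):
            raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
        for i in range(len(src_shape)):
            if shape[i] != src_shape[i] and src_shape[i] != 1:
                raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                                 f" must match the existing size ({src_shape[i]}) at non-singleton"
                                 f" dimension {i}: {src_shape}, {shape}")

    check_broadcastable([1, 32], [64, 32])       # ok: dim 0 is a singleton
    try:
        check_broadcastable([16, 32], [64, 32])  # dim 0 is 16: neither 64 nor 1
    except ValueError as e:
        print(e)

This is exactly what the two new test cases exercise: shape_z_dim_mismatch (64 vs. 32 in every dimension) trips the per-dimension check, while shape_z_rank_mismatch trips the pre-existing rank check at the top of this hunk.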