diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 6987d0c26..92c854f06 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -385,6 +385,9 @@ def test_index1d(expr, dtype_str, device='cuda'):
     rank_y = expr.count(',') + 1
     shape_x = [32 for _ in range(rank_x)]
     shape_z = [32 for _ in range(rank_y)]
+    # Output shapes the kernel must reject: one extra dimension, and a different extent.
+    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
+    shape_z_dim_mismatch = [64 for _ in range(rank_y)]
 
     # Triton kernel
     @triton.jit
@@ -395,12 +398,18 @@ def test_index1d(expr, dtype_str, device='cuda'):
         z = GENERATE_TEST_HERE
         tl.store(Z_PTR_EXPR, z)
 
-    to_replace = {
-        'X_PTR_EXPR': make_ptr_str('X', shape_x),
-        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
-        'GENERATE_TEST_HERE': expr,
-    }
-    kernel = patch_kernel(kernel, to_replace)
+    def generate_kernel(shape_x, shape_z):
+        # Build the placeholder substitutions for the given shapes and pass them to patch_kernel.
+        to_replace = {
+            'X_PTR_EXPR': make_ptr_str('X', shape_x),
+            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
+            'GENERATE_TEST_HERE': expr,
+        }
+        return patch_kernel(kernel, to_replace)
+
+    kernel_match = generate_kernel(shape_x, shape_z)
+    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
+    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
 
     # torch result
     x = numpy_random(shape_x, dtype_str=dtype_str)
@@ -409,10 +418,23 @@ def test_index1d(expr, dtype_str, device='cuda'):
     # triton result
     z_tri = to_triton(np.empty_like(z_ref), device=device)
     x_tri = to_triton(x)
-    kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
     # compare
     assert (z_ref == to_numpy(z_tri)).all()
 
+    def catch_compilation_error(kernel):
+        # Launching a mismatched kernel must raise CompilationError. A launch that
+        # succeeds fails the test; any other exception propagates unchanged.
+        try:
+            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
+        except triton.code_gen.CompilationError:
+            pass
+        else:
+            raise AssertionError("expected triton.code_gen.CompilationError for mismatched shapes")
+
+    catch_compilation_error(kernel_dim_mismatch)
+    catch_compilation_error(kernel_rank_mismatch)
+
 
 # ---------------
 # test tuples
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 2d137b904..9025319d6 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -482,6 +482,13 @@ def broadcast_impl_shape(input: tl.tensor,
         raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
     if shape == src_shape:
         return input
+    # A dimension can only be expanded when its current size is 1; any other
+    # size difference is an error (ranks were already checked above).
+    for i, (dst_dim, src_dim) in enumerate(zip(shape, src_shape)):
+        if dst_dim != src_dim and src_dim != 1:
+            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({dst_dim})"
+                             f" must match the existing size ({src_dim}) at non-singleton dimension"
+                             f" {i}: {src_shape}, {shape}")
     ret_ty = tl.block_type(input.type.scalar, shape)
     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
 