[FRONTEND] Raise broadcast error (#555)
This commit is contained in:
@@ -385,6 +385,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
rank_y = expr.count(',') + 1
|
||||
shape_x = [32 for _ in range(rank_x)]
|
||||
shape_z = [32 for _ in range(rank_y)]
|
||||
shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
|
||||
shape_z_dim_mismatch = [64 for _ in range(rank_y)]
|
||||
|
||||
# Triton kernel
|
||||
@triton.jit
|
||||
@@ -395,12 +397,17 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z_PTR_EXPR, z)
|
||||
|
||||
to_replace = {
|
||||
'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
||||
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
||||
'GENERATE_TEST_HERE': expr,
|
||||
}
|
||||
kernel = patch_kernel(kernel, to_replace)
|
||||
def generate_kernel(shape_x, shape_z):
|
||||
to_replace = {
|
||||
'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
||||
'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
|
||||
'GENERATE_TEST_HERE': expr,
|
||||
}
|
||||
return patch_kernel(kernel, to_replace)
|
||||
|
||||
kernel_match = generate_kernel(shape_x, shape_z)
|
||||
kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
|
||||
kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)
|
||||
|
||||
# torch result
|
||||
x = numpy_random(shape_x, dtype_str=dtype_str)
|
||||
@@ -409,10 +416,21 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
# triton result
|
||||
z_tri = to_triton(np.empty_like(z_ref), device=device)
|
||||
x_tri = to_triton(x)
|
||||
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
# compare
|
||||
assert (z_ref == to_numpy(z_tri)).all()
|
||||
|
||||
def catch_compilation_error(kernel):
|
||||
try:
|
||||
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
except triton.code_gen.CompilationError as e:
|
||||
np.testing.assert_(True)
|
||||
except BaseException:
|
||||
np.testing.assert_(False)
|
||||
|
||||
catch_compilation_error(kernel_dim_mismatch)
|
||||
catch_compilation_error(kernel_rank_mismatch)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test tuples
|
||||
|
@@ -482,6 +482,11 @@ def broadcast_impl_shape(input: tl.tensor,
|
||||
raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
|
||||
if shape == src_shape:
|
||||
return input
|
||||
for i in range(len(src_shape)):
|
||||
if shape[i] != src_shape[i] and src_shape[i] != 1:
|
||||
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
||||
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
|
||||
f" {i}: {src_shape}, {shape}")
|
||||
ret_ty = tl.block_type(input.type.scalar, shape)
|
||||
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
||||
|
||||
|
Reference in New Issue
Block a user