[FRONTEND] fix broadcasting for where (#729)

Fixes #532, all 3 inputs to where have to be broadcast together.
This commit is contained in:
Natalia Gimelshein
2022-10-01 13:18:47 -07:00
committed by GitHub
parent b244db06da
commit f55960e773
2 changed files with 28 additions and 2 deletions

View File

@@ -410,9 +410,34 @@ def test_where(dtype):
assert (z == to_numpy(z_tri)).all()
def test_where_broadcast():
@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
mask = tl.load(cond_ptr + yoffsets)
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
SIZE = 32
dtype = 'float32'
rs = RandomState(17)
x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
mask = numpy_random(SIZE, 'bool', rs=rs)
z = np.where(mask, x, 0)
cond_tri = to_triton(mask, device="cuda")
x_tri = to_triton(x, device='cuda', dst_type=dtype)
z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
assert (z == to_numpy(z_tri)).all()
# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
] + [

View File

@@ -973,8 +973,9 @@ def where(condition: tl.tensor,
builder: ir.builder) -> tl.tensor:
condition = cast(condition, tl.int1, builder)
if condition.type.is_block():
x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder)
y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder)
condition, x = broadcast_impl_value(condition, x, builder)
x, y = broadcast_impl_value(x, y, builder)
condition, x = broadcast_impl_value(condition, x, builder)
x, y = binary_op_type_checking_impl(x, y, builder, True, True)
ret_ty = x.type