[FRONTEND] properly broadcast scalar where condition (#736)

This commit is contained in:
Natalia Gimelshein
2022-10-04 12:44:03 -07:00
committed by GitHub
parent 2b0f877fad
commit d3c925db8a
2 changed files with 14 additions and 0 deletions

View File

@@ -421,6 +421,15 @@ def test_where_broadcast():
res = tl.where(mask, vals, 0.) res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
@triton.jit
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
mask = 0
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
SIZE = 32 SIZE = 32
dtype = 'float32' dtype = 'float32'
rs = RandomState(17) rs = RandomState(17)
@@ -432,6 +441,9 @@ def test_where_broadcast():
z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype) z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
assert (z == to_numpy(z_tri)).all() assert (z == to_numpy(z_tri)).all()
where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
z = np.where(0, x, 0)
assert (z == to_numpy(z_tri)).all()
# --------------- # ---------------
# test unary ops # test unary ops

View File

@@ -978,6 +978,8 @@ def where(condition: tl.tensor,
condition, x = broadcast_impl_value(condition, x, builder) condition, x = broadcast_impl_value(condition, x, builder)
x, y = binary_op_type_checking_impl(x, y, builder, True, True) x, y = binary_op_type_checking_impl(x, y, builder, True, True)
if not condition.type.is_block():
condition, _ = broadcast_impl_value(condition, x, builder)
ret_ty = x.type ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)