[FRONTEND][BACKEND] Fixed various bugs (#819)

- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
This commit is contained in:
Philippe Tillet
2022-10-28 23:34:14 -07:00
committed by GitHub
parent 82834d34f9
commit 7dfab26a39
5 changed files with 74 additions and 52 deletions

View File

@@ -411,41 +411,40 @@ def test_where(dtype):
assert (z == to_numpy(z_tri)).all()
# TODO: wrong result
# def test_where_broadcast():
# @triton.jit
# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
def test_where_broadcast():
@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = tl.load(cond_ptr + yoffsets)
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
# res = tl.where(mask, vals, 0.)
# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
mask = tl.load(cond_ptr + yoffsets)
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
# @triton.jit
# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = 0
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
# res = tl.where(mask, vals, 0.)
# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
@triton.jit
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = 0
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
# SIZE = 32
# dtype = 'float32'
# rs = RandomState(17)
# x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
# mask = numpy_random(SIZE, 'bool', rs=rs)
# z = np.where(mask, x, 0)
# cond_tri = to_triton(mask, device="cuda")
# x_tri = to_triton(x, device='cuda', dst_type=dtype)
# z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
# where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
# assert (z == to_numpy(z_tri)).all()
# where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
# z = np.where(0, x, 0)
# assert (z == to_numpy(z_tri)).all()
SIZE = 32
dtype = 'float32'
rs = RandomState(17)
x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
mask = numpy_random(SIZE, 'bool', rs=rs)
z = np.where(mask, x, 0)
cond_tri = to_triton(mask, device="cuda")
x_tri = to_triton(x, device='cuda', dst_type=dtype)
z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
assert (z == to_numpy(z_tri)).all()
where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
z = np.where(0, x, 0)
assert (z == to_numpy(z_tri)).all()
# # ---------------
# # test unary ops
@@ -719,7 +718,7 @@ def test_tuples():
# ('bfloat16', 'float32', False),
('float32', 'int32', True),
# TODO:
# ('float32', 'int1', False),
('float32', 'int1', False),
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [