[FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs in layout conversions for int1 data: int1 values should be stored as int8 internally, to prevent LLVM from using vec<i1>, which has different semantics.
- Fixed the semantics of some casts to bool in the frontend.
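For context, here is a minimal NumPy-only sketch (not part of the commit; the sample values are made up) of the cast-to-bool semantics the frontend is expected to match: any nonzero value maps to True. Routing the cast through an integer type first truncates values strictly between -1 and 1 to zero and gives a different answer, which is presumably the kind of mismatch the bool-cast fix addresses.

import numpy as np

x = np.array([0.0, 0.5, 1.0, -2.0], dtype=np.float32)

direct = x.astype(bool)                    # [False, True, True, True]: nonzero -> True
via_int = x.astype(np.int32).astype(bool)  # [False, False, True, True]: 0.5 truncates to 0 first

assert direct.tolist() == [False, True, True, True]
assert via_int.tolist() == [False, False, True, True]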
@@ -411,41 +411,40 @@ def test_where(dtype):
     assert (z == to_numpy(z_tri)).all()


-# TODO: wrong result
-# def test_where_broadcast():
-#    @triton.jit
-#    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+def test_where_broadcast():
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]

-#        mask = tl.load(cond_ptr + yoffsets)
-#        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#        res = tl.where(mask, vals, 0.)
-#        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+        mask = tl.load(cond_ptr + yoffsets)
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

-#    @triton.jit
-#    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
-#        mask = 0
-#        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#        res = tl.where(mask, vals, 0.)
-#        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+    @triton.jit
+    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+        mask = 0
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

-#    SIZE = 32
-#    dtype = 'float32'
-#    rs = RandomState(17)
-#    x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
-#    mask = numpy_random(SIZE, 'bool', rs=rs)
-#    z = np.where(mask, x, 0)
-#    cond_tri = to_triton(mask, device="cuda")
-#    x_tri = to_triton(x, device='cuda', dst_type=dtype)
-#    z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
-#    where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
-#    assert (z == to_numpy(z_tri)).all()
-#    where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
-#    z = np.where(0, x, 0)
-#    assert (z == to_numpy(z_tri)).all()
+    SIZE = 32
+    dtype = 'float32'
+    rs = RandomState(17)
+    x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
+    mask = numpy_random(SIZE, 'bool', rs=rs)
+    z = np.where(mask, x, 0)
+    cond_tri = to_triton(mask, device="cuda")
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
+    where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
+    assert (z == to_numpy(z_tri)).all()
+    where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
+    z = np.where(0, x, 0)
+    assert (z == to_numpy(z_tri)).all()

 # # ---------------
 # # test unary ops
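As a reading aid for the hunk above, here is a hedged NumPy-only sketch (not part of the diff) of the host-side reference the re-enabled test compares the kernels against: a 1-D boolean mask broadcast against a 2-D input, and a scalar false-y condition selecting the "else" operand everywhere.

import numpy as np

SIZE = 4
rs = np.random.RandomState(17)
x = rs.uniform(size=(SIZE, SIZE)).astype(np.float32)
mask = rs.randint(0, 2, size=SIZE).astype(bool)

# The 1-D mask is broadcast against the 2-D input: the same mask is applied to
# every row, mirroring the (1, BLOCK_SIZE) load of cond_ptr in where_kernel.
z_broadcast = np.where(mask, x, 0)
assert z_broadcast.shape == (SIZE, SIZE)

# A scalar false-y condition picks the "else" operand everywhere, which is
# what where_scalar_condition (mask = 0) must reproduce.
z_scalar = np.where(0, x, 0)
assert (z_scalar == 0).all()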
@@ -719,7 +718,7 @@ def test_tuples():
     # ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
     # TODO:
-    # ('float32', 'int1', False),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [