[FRONTEND] fix bool conversion of floating types (#545)

Author: Natalia Gimelshein
Date: 2022-06-13 15:52:37 -07:00
Committed by: GitHub
Parent: 38573d1261
Commit: 7094657aa9
2 changed files with 10 additions and 9 deletions
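In short: casting a floating-point value to tl.int1 (Triton's bool type) used to go through a float-to-unsigned-integer conversion, which truncates toward zero, so a value like 0.5 came out as False. With this commit the cast is built as a comparison against zero instead, giving the usual boolean semantics. A rough host-side illustration of the two behaviours (the 0.5 value mirrors the one added to the test below):

    import numpy as np

    x = np.float32(0.5)
    int(x)   # 0    -- truncating conversion: as a bool this would be False
    bool(x)  # True -- "x != 0" semantics, which the cast now follows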


@@ -557,6 +557,7 @@ def test_atomic_cas():
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [
@@ -565,6 +566,8 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
     if dtype_x.startswith('bfloat'):
         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
     else:
@@ -578,11 +581,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(Z, z)
 
+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
@@ -591,9 +595,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         assert z_tri == z_ref
     else:
         if bitcast:
-            z_ref = x.view(getattr(np, dtype_z))
+            z_ref = x.view(getattr(np, dtype_z_np))
         else:
-            z_ref = x.astype(getattr(np, dtype_z))
+            z_ref = x.astype(getattr(np, dtype_z_np))
         assert to_numpy(z_tri) == z_ref
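On the reference side, NumPy has no 'int1', so the test maps it to np.bool_ (the new dtype_z_np) for both the output buffer and the expected value, and picks x0 = 0.5 for float-to-int1 cases because that is exactly where a truncating conversion and a boolean conversion disagree. The expected result therefore reduces to something like:

    import numpy as np

    x = np.full((1,), 0.5, dtype=np.float32)
    z_ref = x.astype(np.bool_)  # array([ True]) -- what the Triton result must match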


@@ -565,7 +565,6 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
-
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
@@ -601,9 +600,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
         # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
-                                                     dst_ty.to_ir(builder)),
-                             dst_ty)
+            return not_equal(input, tl._to_tensor(0, builder), builder)
         else:
             return tl.tensor(builder.create_fp_to_si(input.handle,
                                                      dst_ty.to_ir(builder)),
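On the frontend side, the bool case now returns not_equal(input, 0) instead of emitting a fp-to-ui instruction; the comparison already yields an int1 tensor, so nothing else is needed. In kernel terms, x.to(tl.int1) now behaves like writing x != 0 directly. A minimal sketch of that equivalence (the kernel and buffer names are illustrative, not part of the commit):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _cast_vs_compare(X, A, B):
        x = tl.load(X)
        tl.store(A, x.to(tl.int1))  # the cast path changed in this hunk
        tl.store(B, x != 0.0)       # the comparison it is now built from

    x = torch.tensor([0.5], dtype=torch.float32, device='cuda')
    a = torch.empty((1,), dtype=torch.bool, device='cuda')
    b = torch.empty((1,), dtype=torch.bool, device='cuda')
    _cast_vs_compare[(1,)](x, a, b)
    assert a.item() is True and b.item() is True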
@@ -735,8 +732,8 @@ def store(ptr: tl.tensor,
     elt_ty = ptr_ty.element_ty
     # treat bool* as tl.int8*
     if elt_ty == tl.int1:
-        elt_ty = tl.int8
-        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        elt_ty_ptr = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)
     # cast to target data-type
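The store() hunk is a rename rather than a separate behaviour change: the int8 type used to reinterpret a bool pointer now lives in its own variable (elt_ty_ptr), so elt_ty itself presumably stays tl.int1 for the value cast performed right after this hunk (that line is not shown here), letting stored values take the fixed bool-conversion path. The pointer is still retyped because a bool element occupies a full byte in memory, e.g. in PyTorch:

    import torch

    torch.tensor([True]).element_size()  # 1 -- bools are stored one byte per element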