From 7094657aa993aa9c29d80a44446ff20956ec52e7 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Mon, 13 Jun 2022 15:52:37 -0700
Subject: [PATCH] [FRONTEND] fix bool conversion of floating types (#545)

---
 python/test/unit/language/test_core.py | 10 +++++++---
 python/triton/language/semantic.py     |  9 +++------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index b2b2cdeb1..c76cbbd95 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -557,6 +557,7 @@ def test_atomic_cas():
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [
@@ -565,6 +566,8 @@
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
     if dtype_x.startswith('bfloat'):
         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
     else:
@@ -578,11 +581,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(Z, z)

+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
@@ -591,9 +595,9 @@
         assert z_tri == z_ref
     else:
         if bitcast:
-            z_ref = x.view(getattr(np, dtype_z))
+            z_ref = x.view(getattr(np, dtype_z_np))
         else:
-            z_ref = x.astype(getattr(np, dtype_z))
+            z_ref = x.astype(getattr(np, dtype_z_np))
         assert to_numpy(z_tri) == z_ref


diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index e1c8e6028..753944285 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -565,7 +565,6 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
-
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
@@ -601,9 +600,7 @@
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
         # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
-                                                     dst_ty.to_ir(builder)),
-                             dst_ty)
+            return not_equal(input, tl._to_tensor(0, builder), builder)
         else:
             return tl.tensor(builder.create_fp_to_si(input.handle,
                                                      dst_ty.to_ir(builder)),
@@ -735,8 +732,8 @@ def store(ptr: tl.tensor,
     elt_ty = ptr_ty.element_ty
     # treat bool* as tl.int8*
     if elt_ty == tl.int1:
-        elt_ty = tl.int8
-        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        elt_ty_ptr = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)

     # cast to target data-type
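
Illustration (not part of the patch): a minimal sketch of the behavior this change fixes, assuming a CUDA device and the Triton Python API of this release; the kernel and variable names are hypothetical. Before the fix, a float-to-bool cast was lowered with fp-to-ui, so a fractional value such as 0.5 truncated to 0 (False); with the fix, the cast is lowered as `x != 0` (the `not_equal` call in semantic.py), so any nonzero float converts to True, matching NumPy's `astype(bool_)`.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def float_to_bool(X, Z):
        x = tl.load(X)
        z = x.to(tl.int1)  # with this patch: lowered as `x != 0`
        tl.store(Z, z)     # bool* is stored through int8* (see store())

    x_tri = torch.tensor([0.5], dtype=torch.float32, device='cuda')
    z_tri = torch.zeros((1,), dtype=torch.bool, device='cuda')
    float_to_bool[(1,)](x_tri, z_tri)
    assert z_tri.item()  # old lowering truncated 0.5 to 0, i.e. False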