From 09cc2d454b442301e88d1df153214732bd8714d8 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 10 Oct 2022 15:11:50 -0400 Subject: [PATCH] [FRONTEND] Fix a bool tensor storing problem (#746) --- python/test/unit/language/test_core.py | 19 +++++++++++++++++++ python/triton/language/semantic.py | 6 ++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9cdd7885c..82a69cd42 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -756,6 +756,25 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): assert to_numpy(z_tri) == z_ref +def test_store_bool(): + """Tests that boolean True is stored as 1""" + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + src = torch.tensor([True, False], dtype=torch.bool, device='cuda') + n_elements = src.numel() + dst = torch.empty_like(src) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024) + + assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all() + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_f8_xf16_roundtrip(dtype): """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index caf7e1bc1..11306851c 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -761,8 +761,10 @@ def store(ptr: tl.tensor, elt_ty = ptr_ty.element_ty # treat bool* as tl.int8* if elt_ty == tl.int1: - elt_ty_ptr = tl.int8 - ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space) + # convert to bool first and then store as int8 + val = cast(val, tl.int1, builder) + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) ptr = cast(ptr, ptr_ty, builder) # eviction policy eviction = _parse_eviction_policy(eviction_policy)