[FRONTEND] Fix a bool tensor storing problem (#746)

This commit is contained in:
Bin Bao
2022-10-10 15:11:50 -04:00
committed by GitHub
parent 5d4b26d380
commit 09cc2d454b
2 changed files with 23 additions and 2 deletions

View File

@@ -756,6 +756,25 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref assert to_numpy(z_tri) == z_ref
def test_store_bool():
"""Tests that boolean True is stored as 1"""
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
n_elements = src.numel()
dst = torch.empty_like(src)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype): def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value""" """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""

View File

@@ -761,8 +761,10 @@ def store(ptr: tl.tensor,
elt_ty = ptr_ty.element_ty elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8* # treat bool* as tl.int8*
if elt_ty == tl.int1: if elt_ty == tl.int1:
elt_ty_ptr = tl.int8 # convert to bool first and then store as int8
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space) val = cast(val, tl.int1, builder)
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder) ptr = cast(ptr, ptr_ty, builder)
# eviction policy # eviction policy
eviction = _parse_eviction_policy(eviction_policy) eviction = _parse_eviction_policy(eviction_policy)