[FRONTEND] Fix a bool tensor storing problem (#746)

This commit is contained in:
Bin Bao
2022-10-10 15:11:50 -04:00
committed by GitHub
parent 5d4b26d380
commit 09cc2d454b
2 changed files with 23 additions and 2 deletions

View File

@@ -756,6 +756,25 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref
def test_store_bool():
"""Tests that boolean True is stored as 1"""
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
n_elements = src.numel()
dst = torch.empty_like(src)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""

View File

@@ -761,8 +761,10 @@ def store(ptr: tl.tensor,
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
elt_ty_ptr = tl.int8
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
# convert to bool first and then store as int8
val = cast(val, tl.int1, builder)
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)