[FRONTEND] Fix a bool tensor storing problem (#746)
This commit is contained in:
@@ -756,6 +756,25 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
assert to_numpy(z_tri) == z_ref
|
||||
|
||||
|
||||
def test_store_bool():
|
||||
"""Tests that boolean True is stored as 1"""
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
input = tl.load(input_ptr + offsets, mask=mask)
|
||||
output = input
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
|
||||
n_elements = src.numel()
|
||||
dst = torch.empty_like(src)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
|
@@ -761,8 +761,10 @@ def store(ptr: tl.tensor,
|
||||
elt_ty = ptr_ty.element_ty
|
||||
# treat bool* as tl.int8*
|
||||
if elt_ty == tl.int1:
|
||||
elt_ty_ptr = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
|
||||
# convert to bool first and then store as int8
|
||||
val = cast(val, tl.int1, builder)
|
||||
elt_ty = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
# eviction policy
|
||||
eviction = _parse_eviction_policy(eviction_policy)
|
||||
|
Reference in New Issue
Block a user