[FRONTEND] Fix a bool tensor storing problem (#746)
This commit is contained in:
@@ -756,6 +756,25 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
|||||||
assert to_numpy(z_tri) == z_ref
|
assert to_numpy(z_tri) == z_ref
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_bool():
|
||||||
|
"""Tests that boolean True is stored as 1"""
|
||||||
|
@triton.jit
|
||||||
|
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||||
|
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < n_elements
|
||||||
|
input = tl.load(input_ptr + offsets, mask=mask)
|
||||||
|
output = input
|
||||||
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
|
src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
|
||||||
|
n_elements = src.numel()
|
||||||
|
dst = torch.empty_like(src)
|
||||||
|
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||||
|
copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
|
||||||
|
|
||||||
|
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
def test_f8_xf16_roundtrip(dtype):
|
def test_f8_xf16_roundtrip(dtype):
|
||||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||||
|
@@ -761,8 +761,10 @@ def store(ptr: tl.tensor,
|
|||||||
elt_ty = ptr_ty.element_ty
|
elt_ty = ptr_ty.element_ty
|
||||||
# treat bool* as tl.int8*
|
# treat bool* as tl.int8*
|
||||||
if elt_ty == tl.int1:
|
if elt_ty == tl.int1:
|
||||||
elt_ty_ptr = tl.int8
|
# convert to bool first and then store as int8
|
||||||
ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
|
val = cast(val, tl.int1, builder)
|
||||||
|
elt_ty = tl.int8
|
||||||
|
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||||
ptr = cast(ptr, ptr_ty, builder)
|
ptr = cast(ptr, ptr_ty, builder)
|
||||||
# eviction policy
|
# eviction policy
|
||||||
eviction = _parse_eviction_policy(eviction_policy)
|
eviction = _parse_eviction_policy(eviction_policy)
|
||||||
|
Reference in New Issue
Block a user