[FRONTEND] fix bool conversion of floating types (#545)
parent 38573d1261
commit 7094657aa9
@@ -557,6 +557,7 @@ def test_atomic_cas():
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [
@@ -565,6 +566,8 @@ def test_atomic_cas():
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
     if dtype_x.startswith('bfloat'):
         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
     else:
@@ -578,11 +581,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(Z, z)

+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
@@ -591,9 +595,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         assert z_tri == z_ref
     else:
         if bitcast:
-            z_ref = x.view(getattr(np, dtype_z))
+            z_ref = x.view(getattr(np, dtype_z_np))
         else:
-            z_ref = x.astype(getattr(np, dtype_z))
+            z_ref = x.astype(getattr(np, dtype_z_np))
         assert to_numpy(z_tri) == z_ref


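The test changes above add a float32 -> int1 case, use 0.5 as the input when casting a float type to int1, and map Triton's int1 to numpy's bool_ (numpy has no 1-bit integer type). A minimal standalone sketch, not part of the commit, of why 0.5 is a useful probe value: a boolean cast treats any non-zero float as True, whereas a truncating integer cast collapses 0.5 to 0.

import numpy as np

x = np.array([0.5], dtype=np.float32)

# numpy has no 'int1'; the closest stand-in for Triton's int1 is np.bool_
z_bool = x.astype(np.bool_)   # boolean cast: 0.5 != 0 -> True
z_trunc = x.astype(np.int32)  # truncating cast: int(0.5) -> 0

assert z_bool[0] and z_trunc[0] == 0

The remaining hunks below are in the compiler frontend: they change how cast() lowers float -> bool and how store() handles bool pointers.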
@@ -565,7 +565,6 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
-
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
@@ -601,9 +600,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
         # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
-                                                     dst_ty.to_ir(builder)),
-                             dst_ty)
+            return not_equal(input, tl._to_tensor(0, builder), builder)
         else:
             return tl.tensor(builder.create_fp_to_si(input.handle,
                                                      dst_ty.to_ir(builder)),
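This cast() change is the core of the fix: a floating-point value converted to bool (int1) is now lowered as a comparison against zero via not_equal, instead of an fp-to-unsigned-integer conversion, so any non-zero float maps to True. A rough plain-Python sketch of the intended semantics (an illustration under that reading, not Triton's builder API):

def float_to_int1(x: float) -> bool:
    # bool semantics follow "x != 0", matching numpy's astype(np.bool_)
    return x != 0.0

assert float_to_int1(0.5) is True     # a truncating fp -> int cast would give 0 here
assert float_to_int1(0.0) is False
assert float_to_int1(-43.5) is True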
@@ -735,8 +732,8 @@ def store(ptr: tl.tensor,
     elt_ty = ptr_ty.element_ty
     # treat bool* as tl.int8*
     if elt_ty == tl.int1:
-        elt_ty = tl.int8
-        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        elt_ty_ptr = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)

     # cast to target data-type
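In store(), the int8 element type used for the pointer now gets its own name (elt_ty_ptr) instead of overwriting elt_ty, so the pointer is still treated as int8* while the stored value is presumably still cast against the original bool element type in the "cast to target data-type" step that follows. A small numpy analogy for why bool storage can go through the int8 path (an illustration only, not Triton IR):

import numpy as np

buf = np.zeros(4, dtype=np.bool_)     # bool storage is one byte per element
buf.view(np.int8)[:] = [1, 0, 1, 1]   # write 0/1 bytes through an int8 view of the same memory

assert buf.tolist() == [True, False, True, True]
assert buf.itemsize == 1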