[FRONTEND] fix bool conversion of floating types (#545)
parent 38573d1261 · commit 7094657aa9
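The change in a nutshell: casting a floating-point value to bool (Triton's int1) used to go through a float-to-unsigned-integer truncation, so a nonzero value like 0.5 came out as False; the fix lowers the cast as a comparison with zero instead. A minimal NumPy sketch of the intended semantics (illustration only, not Triton code):

import numpy as np

x = np.float32(0.5)

# What the old lowering effectively did: truncate to an integer first.
wrong = bool(np.uint32(x))   # 0.5 -> 0 -> False

# What a bool conversion should do: compare against zero.
right = bool(x != 0)         # 0.5 != 0 -> True

assert not wrong and right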
@@ -557,6 +557,7 @@ def test_atomic_cas():
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [
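The new ('float32', 'int1', False) triple adds a float-to-bool case to the cast test matrix; bitcast is False because a 32-bit float cannot be reinterpreted bit-for-bit as a 1-bit bool, so this has to be a value cast. As a sketch, the triples drive a pytest parametrization along these lines (decorator shape and test body are assumptions, standing in for the real test_cast):

import pytest

# Assumed shape of the parametrization that consumes these triples.
cast_cases = [
    ('float32', 'bfloat16', False),
    ('bfloat16', 'float32', False),
    ('float32', 'int32', True),
    ('float32', 'int1', False),   # new case: float -> bool is a value cast
] + [
    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
]

@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", cast_cases)
def test_cast_case_is_well_formed(dtype_x, dtype_z, bitcast):
    # Hypothetical smoke check standing in for the real test_cast body.
    assert isinstance(bitcast, bool)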
@@ -565,6 +566,8 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
     if dtype_x.startswith('bfloat'):
         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
     else:
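x0 = 0.5 is the value that actually exposes the bug: it is nonzero, yet truncating it to an integer gives 0, so the old float-to-unsigned lowering would produce False while a correct bool conversion (compare with zero) gives True. A quick NumPy check of that distinction:

import numpy as np

x0 = np.float32(0.5)
assert int(x0) == 0              # truncation throws the value away
assert bool(x0 != 0) is True     # compare-with-zero keeps it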
@@ -578,11 +581,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
         tl.store(Z, z)

+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
     else:
-        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device)
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
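dtype_z_np exists because NumPy has no 1-bit integer type: when the Triton destination type is int1, the reference buffer and reference computation have to use NumPy's bool_ instead. A tiny helper (hypothetical name; the test inlines the expression) capturing that mapping:

import numpy as np

def np_dtype_name(dtype_z: str) -> str:
    # Triton's int1 has no NumPy counterpart; map it to bool_.
    return dtype_z if dtype_z != 'int1' else 'bool_'

assert getattr(np, np_dtype_name('int32')) is np.int32
assert getattr(np, np_dtype_name('int1')) is np.bool_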
@@ -591,9 +595,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         assert z_tri == z_ref
     else:
         if bitcast:
-            z_ref = x.view(getattr(np, dtype_z))
+            z_ref = x.view(getattr(np, dtype_z_np))
         else:
-            z_ref = x.astype(getattr(np, dtype_z))
+            z_ref = x.astype(getattr(np, dtype_z_np))
         assert to_numpy(z_tri) == z_ref
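The reference path mirrors the same remapping: for a value cast to bool, NumPy's astype(np.bool_) already implements compare-with-zero semantics, which is what the fixed frontend now does, while a bitcast (view) is only meaningful between types of equal width. For example:

import numpy as np

x = np.array([0.5], dtype=np.float32)

# Value cast to bool: nonzero -> True, matching the fixed frontend semantics.
assert x.astype(np.bool_)[0]

# Bitcast between equal-width types reinterprets the raw bits instead.
assert x.view(np.int32)[0] == 0x3F000000   # bit pattern of 0.5f

The hunks below leave the test suite and patch the frontend lowering itself, in the cast and store helpers.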
@@ -565,7 +565,6 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
-
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
@@ -601,9 +600,7 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
         # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
-                                                     dst_ty.to_ir(builder)),
-                             dst_ty)
+            return not_equal(input, tl._to_tensor(0, builder), builder)
         else:
             return tl.tensor(builder.create_fp_to_si(input.handle,
                                                      dst_ty.to_ir(builder)),
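This is the core of the fix: a float-to-bool cast is no longer lowered with create_fp_to_ui, which truncates toward zero, but as not_equal(input, 0), i.e. x != 0. The builder calls above belong to Triton's internal IR builder; the sketch below only mirrors the semantics in NumPy:

import numpy as np

def float_to_bool(x: np.ndarray) -> np.ndarray:
    # Old lowering (buggy): truncate to an unsigned int, e.g. 0.5 -> 0 -> False.
    # New lowering: compare against zero, like not_equal(input, 0).
    return x != 0

vals = np.array([0.0, 0.5, -3.25], dtype=np.float32)
assert float_to_bool(vals).tolist() == [False, True, True]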
@@ -735,8 +732,8 @@ def store(ptr: tl.tensor,
     elt_ty = ptr_ty.element_ty
     # treat bool* as tl.int8*
     if elt_ty == tl.int1:
-        elt_ty = tl.int8
-        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        elt_ty_ptr = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)

     # cast to target data-type