fp8 <> bf16 conversion (#637)

commit 437ced38c2 (parent 210a296699)
Author: Da Yan
Date:   2022-08-31 05:20:12 +08:00
Co-authored-by: Philippe Tillet <phil@openai.com>

5 changed files with 133 additions and 10 deletions


@@ -719,8 +719,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     assert to_numpy(z_tri) == z_ref
 
 
-def test_f8_f16_roundtrip():
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_f8_xf16_roundtrip(dtype):
     """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
+    check_type_supported(dtype)
+
     @triton.jit
     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -732,13 +735,13 @@ def test_f8_f16_roundtrip():
     f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
     f8 = triton.reinterpret(f8_tensor, tl.float8)
     n_elements = f8_tensor.numel()
-    f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
+    xf16 = torch.empty_like(f8_tensor, dtype=dtype)
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-    copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024)
+    copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
 
-    f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
+    f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
-    copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
+    copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
 
     assert torch.all(f8_tensor == f8_output_tensor)
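
The body of copy_kernel is cut off in the hunk above. For context, a minimal sketch of the roundtrip, assuming the kernel is a plain masked load/store (the actual test may differ): the fp8 <-> f16/bf16 conversions happen implicitly, because tl.load and tl.store use the element type of whichever pointer they are given.

    # Sketch only, not the commit's exact code.
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(input_ptr + offsets, mask=mask)   # reads the input pointer's dtype (fp8 here)
        tl.store(output_ptr + offsets, x, mask=mask)  # casts x to the output pointer's dtype

The first launch therefore upcasts fp8 -> f16/bf16 and the second launch downcasts back to fp8; the assert checks that the original bytes survive the roundtrip.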
@@ -746,7 +749,6 @@ def test_f8_f16_roundtrip():
 def test_f16_to_f8_rounding():
     """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
     error is the minimum over all float8.
-    Or the same explanation a bit mathier:
     for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
     @triton.jit

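In plain terms, test_f16_to_f8_rounding checks that the f16 -> f8 conversion rounds each input to the nearest representable f8 value. A hypothetical torch-only sketch of that invariant (names are illustrative, not the test's actual variables; all_f8_upcast is assumed to hold all 256 f8 bit patterns converted to f16):

    # Illustration of the docstring's condition, ignoring NaN handling.
    def min_abs_error_to_f8(f16_vals, all_f8_upcast):
        # distance from each f16 input to the closest representable f8 value
        diff = (f16_vals[:, None].float() - all_f8_upcast[None, :].float()).abs()
        return diff.min(dim=1).values

    # The test's condition: |f16 - fromf8(tof8(f16))| == min_abs_error_to_f8(f16, all_f8_upcast)
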

@@ -581,6 +581,13 @@ def cast(input: tl.tensor,
         return input
     src_sca_ty = src_ty.scalar
     dst_sca_ty = dst_ty.scalar
+    # fp8 <=> bf16/fp16
+    if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
+        return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
+                         dst_ty)
+    if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
+        return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
+                         dst_ty)
     # bf16 <=> (not fp32)
     if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
        (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
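
With these two branches, a cast between fp8 and fp16/bf16 lowers directly to a floating-point truncation (create_fp_trunc) or extension (create_fp_ext) instead of falling through to the generic bf16 path below. A hedged usage sketch against the API at this commit (triton.reinterpret and tl.float8, which later versions renamed; bf16 stores assume an Ampere-class GPU), mirroring the test above but upcasting to bfloat16:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def upcast_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(src_ptr + offsets, mask=mask)    # fp8 elements
        tl.store(dst_ptr + offsets, x, mask=mask)    # bf16 destination -> fp_ext path above

    raw = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')  # all 256 fp8 bit patterns
    src = triton.reinterpret(raw, tl.float8)         # view the same bytes as fp8
    dst = torch.empty_like(raw, dtype=torch.bfloat16)
    grid = lambda meta: (triton.cdiv(raw.numel(), meta['BLOCK_SIZE']),)
    upcast_kernel[grid](src, dst, raw.numel(), BLOCK_SIZE=256)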