[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

View File

@@ -780,88 +780,88 @@ def test_store_bool():
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# def test_f8_xf16_roundtrip(dtype):
# """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
# check_type_supported(dtype)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
# f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
# f8 = triton.reinterpret(f8_tensor, tl.float8)
# n_elements = f8_tensor.numel()
# xf16 = torch.empty_like(f8_tensor, dtype=dtype)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
# f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
# assert torch.all(f8_tensor == f8_output_tensor)
assert torch.all(f8_tensor == f8_output_tensor)
# def test_f16_to_f8_rounding():
# """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
# error is the minimum over all float8.
# Or the same explanation a bit mathier:
# for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
def test_f16_to_f8_rounding():
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)
# # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
# f16_input_np = (
# np.array(
# range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
# )
# .view(np.float16)
# )
# f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
# n_elements = f16_input.numel()
# f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
f16_input_np = (
np.array(
range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
)
.view(np.float16)
)
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
# f16_output = torch.empty_like(f16_input, dtype=torch.float16)
# copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
f16_output = torch.empty_like(f16_input, dtype=torch.float16)
copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
# abs_error = torch.abs(f16_input - f16_output)
abs_error = torch.abs(f16_input - f16_output)
# all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
# all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
# all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
# copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
# all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
# torch.isfinite(all_f8_vals_in_f16)
# ]
all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
torch.isfinite(all_f8_vals_in_f16)
]
# min_error = torch.min(
# torch.abs(
# f16_input.reshape((-1, 1))
# - all_finite_f8_vals_in_f16.reshape((1, -1))
# ),
# dim=1,
# )[0]
# # 1.9375 is float8 max
# mismatch = torch.logical_and(
# abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
# )
# assert torch.all(
# torch.logical_not(mismatch)
# ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
min_error = torch.min(
torch.abs(
f16_input.reshape((-1, 1))
- all_finite_f8_vals_in_f16.reshape((1, -1))
),
dim=1,
)[0]
# 1.9375 is float8 max
mismatch = torch.logical_and(
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
)
assert torch.all(
torch.logical_not(mismatch)
), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
# # ---------------