[LANG] Add fp16 to fp8 conversion (#444)
@@ -565,6 +565,90 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    z_ref = x.astype(getattr(np, dtype_z))
    assert to_numpy(z_tri) == z_ref


def test_f8_f16_roundtrip():
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

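    # All 256 int8 values (-128..127); reinterpreting the tensor as float8 below covers every possible 8-bit encoding.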
    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    f8 = triton.reinterpret(f8_tensor, tl.float8)
    n_elements = f8_tensor.numel()
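    # f8 -> f16: the copy kernel's load/store performs the conversion, driven by the element types of the reinterpreted input and the f16 output.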
    f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024)

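    # f16 -> f8: run the same kernel in the other direction, writing through a reinterpreted f8 output.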
    f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    print(f16.dtype, f8_output.dtype)
    copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)

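    # The roundtrip must reproduce the original bit patterns exactly.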
    assert torch.all(f8_tensor == f8_output_tensor)


def test_f16_to_f8_rounding():
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.

    Or the same explanation a bit mathier:
    for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

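    # Enumerate every f16 bit pattern by viewing the full int16 range as float16.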
    # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
    f16_input_np = (
        np.array(
            range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
        )
        .view(np.float16)
    )
    f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
    n_elements = f16_input.numel()
    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

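    # Convert back to f16 so the roundtrip error can be measured against the original values.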
    f16_output = torch.empty_like(f16_input, dtype=torch.float16)
    copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)

    abs_error = torch.abs(f16_input - f16_output)

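    # Brute-force reference: convert every f8 value to f16 to find the closest representable f8 per input.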
    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

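    # Keep only the finite f8 values; torch.isfinite filters out any inf/NaN encodings.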
    all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
        torch.isfinite(all_f8_vals_in_f16)
    ]

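    # Smallest achievable absolute error for each f16 input over all finite f8 values.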
    min_error = torch.min(
        torch.abs(
            f16_input.reshape((-1, 1))
            - all_finite_f8_vals_in_f16.reshape((1, -1))
        ),
        dim=1,
    )[0]
    # 1.9375 is float8 max
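    # Only finite inputs within the representable f8 range (magnitude below 1.9375) have to round optimally.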
    mismatch = torch.logical_and(
        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
    )
    assert torch.all(
        torch.logical_not(mismatch)
    ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}"


# ---------------
# test reduce
# ---------------