[Triton-MLIR] Support FP8 (#864)
Co-authored-by: Superjomn <yanchunwei@outlook.com>
@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
           [](mlir::OpBuilder &self) -> mlir::Type {
             return self.getType<mlir::triton::Float8Type>();
           })
      .def("get_bf8_ty",
           [](mlir::OpBuilder &self) -> mlir::Type {
             return self.getType<mlir::triton::BFloat8Type>();
           })
      .def(
          "get_half_ty",
          [](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
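For orientation, these type getters are what the Python frontend calls when lowering `tl` dtypes to MLIR types. Only the lambda bodies are visible in this hunk, so the getter name wrapping `Float8Type` (presumably `get_fp8_ty`, by analogy with `get_bf8_ty` above) is an assumption, not something shown in the diff. A minimal sketch of that dispatch:

```python
# Sketch only: how a dtype-to-IR helper would reach the bound type getters.
# get_fp8_ty is assumed by analogy with get_bf8_ty; get_half_ty is bound just above.
def scalar_dtype_to_ir(name, builder):
    getters = {
        'fp8': builder.get_fp8_ty,    # assumed binding for mlir::triton::Float8Type
        'fp16': builder.get_half_ty,  # F16Type
    }
    return getters[name]()
```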
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
           })

      // Cast instructions
      // Conversions for custom FP types (FP8)
      .def("create_fp_to_fp",
           [](mlir::OpBuilder &self, mlir::Value &src,
              mlir::Type &dstType) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
           })
      // Conversions for standard LLVM builtin types
      .def("create_bitcast",
           [](mlir::OpBuilder &self, mlir::Value &src,
              mlir::Type &dstType) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
           })
      // .def("create_cast", &ir::builder::create_cast)
      // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
      .def("create_si_to_fp",
           [](mlir::OpBuilder &self, mlir::Value &src,
              mlir::Type &dstType) -> mlir::Value {
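The new `create_fp_to_fp` binding wraps `triton::FpToFpOp`, the op used whenever a custom FP type (fp8) appears on either side of a conversion; `create_bitcast` stays on the standard-type path. The call shape from the Python semantic layer, mirroring the `cast()` hunk further down rather than introducing a new API:

```python
import triton.language as tl

def fp8_convert(input: tl.tensor, dst_ty: tl.dtype, builder) -> tl.tensor:
    # FpToFpOp handles fp8 <=> bf16/fp16/fp32/fp64; see cast() in semantic.py below.
    return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
                     dst_ty)
```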
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::arith::IndexCastOp>(loc, input,
                                                          self.getI32Type());
           })

      .def("create_fmul",
           [](mlir::OpBuilder &self, mlir::Value &lhs,
              mlir::Value &rhs) -> mlir::Value {
@@ -780,88 +780,88 @@ def test_store_bool():
    assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# def test_f8_xf16_roundtrip(dtype):
#     """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
#     check_type_supported(dtype)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    check_type_supported(dtype)

#     @triton.jit
#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#         mask = offsets < n_elements
#         input = tl.load(input_ptr + offsets, mask=mask)
#         output = input
#         tl.store(output_ptr + offsets, output, mask=mask)
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

#     f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
#     f8 = triton.reinterpret(f8_tensor, tl.float8)
#     n_elements = f8_tensor.numel()
#     xf16 = torch.empty_like(f8_tensor, dtype=dtype)
#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
#     copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    f8 = triton.reinterpret(f8_tensor, tl.float8)
    n_elements = f8_tensor.numel()
    xf16 = torch.empty_like(f8_tensor, dtype=dtype)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)

#     f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
#     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
#     copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
    f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)

#     assert torch.all(f8_tensor == f8_output_tensor)
    assert torch.all(f8_tensor == f8_output_tensor)


# def test_f16_to_f8_rounding():
#     """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
#     error is the minimum over all float8.
#     Or the same explanation a bit mathier:
#     for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
#     @triton.jit
#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#         mask = offsets < n_elements
#         input = tl.load(input_ptr + offsets, mask=mask)
#         output = input
#         tl.store(output_ptr + offsets, output, mask=mask)
def test_f16_to_f8_rounding():
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.
    Or the same explanation a bit mathier:
    for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

#     # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
#     f16_input_np = (
#         np.array(
#             range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
#         )
#         .view(np.float16)
#     )
#     f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
#     n_elements = f16_input.numel()
#     f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
#     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
#     copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
    # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
    f16_input_np = (
        np.array(
            range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
        )
        .view(np.float16)
    )
    f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
    n_elements = f16_input.numel()
    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

#     f16_output = torch.empty_like(f16_input, dtype=torch.float16)
#     copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
    f16_output = torch.empty_like(f16_input, dtype=torch.float16)
    copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)

#     abs_error = torch.abs(f16_input - f16_output)
    abs_error = torch.abs(f16_input - f16_output)

#     all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
#     all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
#     all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
#     copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

#     all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
#         torch.isfinite(all_f8_vals_in_f16)
#     ]
    all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
        torch.isfinite(all_f8_vals_in_f16)
    ]

#     min_error = torch.min(
#         torch.abs(
#             f16_input.reshape((-1, 1))
#             - all_finite_f8_vals_in_f16.reshape((1, -1))
#         ),
#         dim=1,
#     )[0]
#     # 1.9375 is float8 max
#     mismatch = torch.logical_and(
#         abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
#     )
#     assert torch.all(
#         torch.logical_not(mismatch)
#     ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
    min_error = torch.min(
        torch.abs(
            f16_input.reshape((-1, 1))
            - all_finite_f8_vals_in_f16.reshape((1, -1))
        ),
        dim=1,
    )[0]
    # 1.9375 is float8 max
    mismatch = torch.logical_and(
        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
    )
    assert torch.all(
        torch.logical_not(mismatch)
    ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"


# # ---------------
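These tests also show the only way to move fp8 data across the kernel boundary for now: torch has no float8 dtype, so the storage is an int8/uint8 tensor that `triton.reinterpret` retypes as `tl.float8` without copying, and the fp8 to fp16 conversion happens inside the copy kernel when the loaded value is stored through an fp16 pointer. A condensed, self-contained version of the round trip's first leg, using the same APIs as the test above:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load fp8 values and store them through an fp16 pointer; the cast happens on the store.
    tl.store(output_ptr + offsets, tl.load(input_ptr + offsets, mask=mask), mask=mask)


storage = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')  # raw fp8 bit patterns
f8 = triton.reinterpret(storage, tl.float8)     # same memory, element type tl.float8
f16 = torch.empty(256, dtype=torch.float16, device='cuda')
copy_kernel[(1,)](f8, f16, 256, BLOCK_SIZE=256)
```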
@@ -48,6 +48,8 @@ class dtype:
    SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
    UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
    FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
    CUSTOMIZED_FP_TYPES = ['fp8']
    STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
    OTHER_TYPES = ['void']

    class SIGNEDNESS(Enum):
@@ -129,6 +131,12 @@ class dtype:
    def is_floating(self):
        return self.name in dtype.FP_TYPES

    def is_customized_floating(self):
        return self.name in dtype.CUSTOMIZED_FP_TYPES

    def is_standard_floating(self):
        return self.name in dtype.STANDARD_FP_TYPES

    def is_int_signed(self):
        return self.name in dtype.SINT_TYPES
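To make the new split concrete, a small illustration (the predicates are the ones added above; `tl.float8` and `tl.float16` are the module-level dtype instances referenced elsewhere in this diff):

```python
import triton.language as tl

# fp8 is the only "customized" floating type; everything else stays on the standard path.
assert tl.float8.is_floating() and tl.float8.is_customized_floating()
assert tl.float16.is_floating() and tl.float16.is_standard_floating()
assert not tl.float16.is_customized_floating()
```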
@@ -613,39 +613,45 @@ def cast(input: tl.tensor,
        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input

    src_sca_ty = src_ty.scalar
    dst_sca_ty = dst_ty.scalar
    # fp8 <=> bf16/fp16
    if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
        return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),

    # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
    if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
       (src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
        return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
                         dst_ty)
    if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
        return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
                         dst_ty)
    # bf16 <=> (not fp32)
    if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
       (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):

    # Casting types of the same bit width: fp16 <=> bf16
    if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \
       (src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()):
        return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)

    # FP Truncation
    # Standard floating types' casting: truncation
    # fp64 => fp32, fp16, bf16
    # fp32 => fp16, bf16
    truncate_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
        src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
    if truncate_fp:
        return tl.tensor(builder.create_fp_trunc(input.handle,
                                                 dst_ty.to_ir(builder)),
                         dst_ty)

    # FP Extension
    # Standard floating types' casting: extension
    # fp32 => fp64
    # fp16 => fp32, fp64
    # bf16 => fp32, fp64
    ext_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
        src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
    if ext_fp:
        return tl.tensor(builder.create_fp_ext(input.handle,
                                               dst_ty.to_ir(builder)),
                         dst_ty)

    # Int cast
    # Casting between integer types
    if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
       (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
        sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
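With this dispatch, any cast that touches fp8 is routed through `create_fp_to_fp` regardless of the other side's width, while standard types keep using truncation/extension chosen by `primitive_bitwidth`. A minimal kernel sketch exercising the fp8 path, assuming the usual `tensor.to(...)` frontend method and an input pointer reinterpreted as `tl.float8`, as in the tests:

```python
import triton
import triton.language as tl


@triton.jit
def fp8_upcast_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)  # fp8 elements when x_ptr is reinterpreted as tl.float8
    y = x.to(tl.float32)                     # hits the "customized floating" branch above
    tl.store(y_ptr + offsets, y, mask=mask)
```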
@@ -658,8 +664,8 @@ def cast(input: tl.tensor,
                                          dst_ty.to_ir(builder), sign_extend),
                         dst_ty)

    # Float to Int
    if src_sca_ty.is_floating() and dst_sca_ty.is_int():
    # Casting standard floating types to integer types
    if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
        if dst_sca_ty.is_bool():
            ty = input.dtype.to_ir(builder)
            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
@@ -673,8 +679,8 @@ def cast(input: tl.tensor,
                                                 dst_ty.to_ir(builder)),
                             dst_ty)

    # int => float
    if src_sca_ty.is_int() and dst_sca_ty.is_floating():
    # Casting integer types to standard floating types
    if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
        if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
            return tl.tensor(builder.create_ui_to_fp(input.handle,
                                                     dst_ty.to_ir(builder)),
@@ -684,7 +690,7 @@ def cast(input: tl.tensor,
                                                 dst_ty.to_ir(builder)),
                         dst_ty)

    # ptr => int
    # Casting pointer types to integer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
        bitwidth = dst_sca_ty.int_bitwidth
        if bitwidth == 64:
@@ -695,19 +701,14 @@ def cast(input: tl.tensor,
                           tl.tensor(builder.get_int64(0), tl.int64),
                           builder)

    if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
    # Casting integer types to pointer types
    if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
        return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
    # Ptr . Ptr

    # Casting pointer types to pointer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
        return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
    # * . Bool
    if dst_sca_ty.is_bool():
        if src_sca_ty.is_ptr():
            input = cast(input, tl.int64, builder)
        other = builder.get_int64(0)
        if src_ty.is_bool():
            other = builder.create_splat(other, src_ty.get_block_shapes())
        return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)

    assert False, f'cannot cast {input} to {dst_ty}'

# ===----------------------------------------------------------------------===//
@@ -176,6 +176,9 @@ class JITFunction(KernelInterface):
                triton.language.uint32: 'u32',
                triton.language.uint64: 'u64',
                triton.language.float8: 'fp8',
                triton.language.float16: 'fp16',
                triton.language.bfloat16: 'bf16',
                triton.language.float32: 'fp32',
            }[key]
            return f'*{ty}'
        if key is None:
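With `triton.language.float8` in this table, a pointer argument whose tensor was reinterpreted as fp8 contributes `'*fp8'` to the kernel's specialization key. Illustrative only, reusing the pattern from the tests above:

```python
import torch
import triton
import triton.language as tl

raw = torch.empty(1024, dtype=torch.int8, device='cuda')
f8_arg = triton.reinterpret(raw, tl.float8)
# Passed to a @triton.jit kernel, f8_arg is keyed as '*fp8' by the mapping above.
```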