fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -148,6 +148,8 @@ private:
|
|||||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||||
Value* bf16_to_fp32(Value *in0);
|
Value* bf16_to_fp32(Value *in0);
|
||||||
Value* fp32_to_bf16(Value *in0);
|
Value* fp32_to_bf16(Value *in0);
|
||||||
|
|
||||||
|
@@ -569,10 +569,10 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
|
|||||||
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
|
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
|
||||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
|
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
|
||||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
|
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
|
||||||
"shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion)
|
"shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 poistion)
|
||||||
"shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position)
|
"shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
|
||||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign)
|
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
|
||||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign)
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
|
||||||
"}", "=r,=r,r", false);
|
"}", "=r,=r,r", false);
|
||||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||||
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
||||||
@@ -635,6 +635,110 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0
|
|||||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) {
|
||||||
|
// current exp offset: 15
|
||||||
|
// Add 112 (127-15) to compensate the difference in exponent bias
|
||||||
|
// bf16 = ((nosign >> (8-4)) + (112 << 7)) | sign;
|
||||||
|
// bf16 = ((nosign >> 4) + 0x3800) | sign;
|
||||||
|
Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)});
|
||||||
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||||
|
"{"
|
||||||
|
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t"
|
||||||
|
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0
|
||||||
|
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0
|
||||||
|
"and.b32 sign0, a0, 0x80008000; \n\t"
|
||||||
|
"and.b32 sign1, a1, 0x80008000; \n\t"
|
||||||
|
"and.b32 nosign0, a0, 0x7fff7fff; \n\t"
|
||||||
|
"and.b32 nosign1, a1, 0x7fff7fff; \n\t"
|
||||||
|
"shr.b32 nosign0, nosign0, 4; \n\t"
|
||||||
|
"shr.b32 nosign1, nosign1, 4; \n\t"
|
||||||
|
"add.u32 nosign0, nosign0, 0x38003800; \n\t"
|
||||||
|
"add.u32 nosign1, nosign1, 0x38003800; \n\t"
|
||||||
|
"or.b32 $0, sign0, nosign0; \n\t"
|
||||||
|
"or.b32 $1, sign1, nosign1; \n\t"
|
||||||
|
"}", "=r,=r,r", false);
|
||||||
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||||
|
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
||||||
|
packed_in = insert_elt(packed_in, in1, (uint64_t)1);
|
||||||
|
packed_in = insert_elt(packed_in, in2, (uint64_t)2);
|
||||||
|
packed_in = insert_elt(packed_in, in3, (uint64_t)3);
|
||||||
|
Value *in = bit_cast(packed_in, i32_ty);
|
||||||
|
Value *ret = call(ptx, {in});
|
||||||
|
Value *packed_ret0 = extract_val(ret, {0});
|
||||||
|
Value *packed_ret1 = extract_val(ret, {1});
|
||||||
|
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
|
||||||
|
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
|
||||||
|
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
|
||||||
|
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
|
||||||
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
|
||||||
|
/* Assuming fp8 exponent offset is 15. bf16 exponent offset is 127.
|
||||||
|
Max value in fp8: 0b01111111 (0x7f),
|
||||||
|
bf16: 0x3ff0
|
||||||
|
Min value in fp8: 0b00000000 (0x00)
|
||||||
|
bf16: 0x3800
|
||||||
|
// @note: +0x8 is half of the lowest kept fp8 mantissa bit (bit 4), so truncating afterwards rounds the magnitude to nearest
|
||||||
|
fp8 = (nosign(bf16) - (112 << 7) + 0x8) >> 4;
|
||||||
|
return fp8 | sign; // also permute bytes
|
||||||
|
*/
|
||||||
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
|
||||||
|
"{\n\t"
|
||||||
|
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t"
|
||||||
|
".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t"
|
||||||
|
"mov.u32 fp8_min, 0x38003800; \n\t"
|
||||||
|
"mov.u32 fp8_max, 0x3ff03ff0; \n\t"
|
||||||
|
"mov.u32 rn_, 0x80008; \n\t"
|
||||||
|
"mov.u32 zero, 0; \n\t"
|
||||||
|
"and.b32 sign0, $1, 0x80008000; \n\t"
|
||||||
|
"and.b32 sign1, $2, 0x80008000; \n\t"
|
||||||
|
"prmt.b32 sign, sign0, sign1, 0x7531; \n\t"
|
||||||
|
"and.b32 nosign0, $1, 0x7fff7fff; \n\t"
|
||||||
|
"and.b32 nosign1, $2, 0x7fff7fff; \n\t"
|
||||||
|
|
||||||
|
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max)
|
||||||
|
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t"
|
||||||
|
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t"
|
||||||
|
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t"
|
||||||
|
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t"
|
||||||
|
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t"
|
||||||
|
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t"
|
||||||
|
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t"
|
||||||
|
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t"
|
||||||
|
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t"
|
||||||
|
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t"
|
||||||
|
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t"
|
||||||
|
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t"
|
||||||
|
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t"
|
||||||
|
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t"
|
||||||
|
|
||||||
|
"add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero
|
||||||
|
"add.u32 nosign1, nosign1, rn_; \n\t"
|
||||||
|
"sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset
|
||||||
|
"sub.u32 nosign1, nosign1, 0x38003800; \n\t"
|
||||||
|
"shr.u32 nosign0, nosign0, 4; \n\t"
|
||||||
|
"shr.u32 nosign1, nosign1, 4; \n\t"
|
||||||
|
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t"
|
||||||
|
"or.b32 $0, nosign, sign; \n\t"
|
||||||
|
""
|
||||||
|
"}", "=r,r,r", false);
|
||||||
|
Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
|
||||||
|
Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
|
||||||
|
packed_in0 = insert_elt(packed_in0, in0, (int)0);
|
||||||
|
packed_in0 = insert_elt(packed_in0, in1, (int)1);
|
||||||
|
packed_in1 = insert_elt(packed_in1, in2, (int)0);
|
||||||
|
packed_in1 = insert_elt(packed_in1, in3, (int)1);
|
||||||
|
Value *in_arg0 = bit_cast(packed_in0, i32_ty);
|
||||||
|
Value *in_arg1 = bit_cast(packed_in1, i32_ty);
|
||||||
|
Value *ret = call(ptx, {in_arg0, in_arg1});
|
||||||
|
Value *ret0 = extract_elt(ret, (int)0);
|
||||||
|
Value *ret1 = extract_elt(ret, (int)1);
|
||||||
|
Value *ret2 = extract_elt(ret, (int)2);
|
||||||
|
Value *ret3 = extract_elt(ret, (int)3);
|
||||||
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||||
|
}
|
||||||
|
|
||||||
Value* generator::bf16_to_fp32(Value *in0){
|
Value* generator::bf16_to_fp32(Value *in0){
|
||||||
if (tgt_->as_nvidia()->sm() >= 80) {
|
if (tgt_->as_nvidia()->sm() >= 80) {
|
||||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
|
||||||
@@ -685,6 +789,11 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
|||||||
return fp8x4_to_fp16x4(a, b, c, d);
|
return fp8x4_to_fp16x4(a, b, c, d);
|
||||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
|
||||||
return fp8x4_to_fp32x4(a, b, c, d);
|
return fp8x4_to_fp32x4(a, b, c, d);
|
||||||
|
// fp8 <> bf16
|
||||||
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty())
|
||||||
|
return fp8x4_to_bf16x4(a, b, c, d);
|
||||||
|
if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty())
|
||||||
|
return bf16x4_to_fp8x4(a, b, c, d);
|
||||||
throw std::runtime_error("unsupported conversion");
|
throw std::runtime_error("unsupported conversion");
|
||||||
};
|
};
|
||||||
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
||||||
|
@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
|
|||||||
else{
|
else{
|
||||||
if(layouts_->has_tmp(v))
|
if(layouts_->has_tmp(v))
|
||||||
return async_write.size() - 1;
|
return async_write.size() - 1;
|
||||||
|
// // Ignore copy_to_shared. It won't modify async behavior.
|
||||||
|
// if(dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||||
|
// return 0;
|
||||||
auto it = std::find(async_write.begin(), async_write.end(), v);
|
auto it = std::find(async_write.begin(), async_write.end(), v);
|
||||||
return std::distance(async_write.begin(), it);
|
return std::distance(async_write.begin(), it);
|
||||||
}
|
}
|
||||||
|
@@ -719,8 +719,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
|||||||
assert to_numpy(z_tri) == z_ref
|
assert to_numpy(z_tri) == z_ref
|
||||||
|
|
||||||
|
|
||||||
def test_f8_f16_roundtrip():
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
def test_f8_xf16_roundtrip(dtype):
|
||||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||||
|
check_type_supported(dtype)
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
@@ -732,13 +735,13 @@ def test_f8_f16_roundtrip():
|
|||||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||||
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
||||||
n_elements = f8_tensor.numel()
|
n_elements = f8_tensor.numel()
|
||||||
f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
|
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
|
||||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||||
copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024)
|
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||||
|
|
||||||
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||||
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||||
|
|
||||||
assert torch.all(f8_tensor == f8_output_tensor)
|
assert torch.all(f8_tensor == f8_output_tensor)
|
||||||
|
|
||||||
@@ -746,7 +749,6 @@ def test_f8_f16_roundtrip():
|
|||||||
def test_f16_to_f8_rounding():
|
def test_f16_to_f8_rounding():
|
||||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||||
error is the minimum over all float8.
|
error is the minimum over all float8.
|
||||||
|
|
||||||
Or the same explanation a bit mathier:
|
Or the same explanation a bit mathier:
|
||||||
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
|
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
@@ -581,6 +581,13 @@ def cast(input: tl.tensor,
|
|||||||
return input
|
return input
|
||||||
src_sca_ty = src_ty.scalar
|
src_sca_ty = src_ty.scalar
|
||||||
dst_sca_ty = dst_ty.scalar
|
dst_sca_ty = dst_ty.scalar
|
||||||
|
# fp8 <=> bf16/fp16
|
||||||
|
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
|
||||||
|
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
|
||||||
|
dst_ty)
|
||||||
|
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
|
||||||
|
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
|
||||||
|
dst_ty)
|
||||||
# bf16 <=> (not fp32)
|
# bf16 <=> (not fp32)
|
||||||
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
|
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
|
||||||
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
|
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
|
||||||
|
Reference in New Issue
Block a user