fp8 <> bf16 conversion (#637)

Co-authored-by: Philippe Tillet <phil@openai.com>
Author: Da Yan
Date: 2022-08-31 05:20:12 +08:00
Committed by: GitHub
Parent: 210a296699
Commit: 437ced38c2
5 changed files with 133 additions and 10 deletions


@@ -148,6 +148,8 @@ private:
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);


@@ -569,10 +569,10 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
"shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion) "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 poistion)
"shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position) "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign) "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign) "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
"}", "=r,=r,r", false); "}", "=r,=r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
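For reference, a minimal pure-Python model of one lane of this fp8 -> fp16 routine (illustrative, not part of the commit; it assumes the 1-4-3 fp8 layout with exponent bias 15 implied by the fp8x4_to_bf16x4 comments further down):

    def fp8_to_fp16_bits(b: int) -> int:
        # prmt: place the fp8 byte in the high byte of a 16-bit lane
        h = (b & 0xff) << 8
        sign = h & 0x8000            # save the sign bit
        nosign = h & 0x7fff          # strip the sign
        # One right shift lines the 4-bit exponent / 3-bit mantissa up
        # under fp16's 5-bit exponent / 10-bit mantissa; both formats
        # use exponent bias 15, so no bias correction is needed.
        return (nosign >> 1) | sign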
@@ -635,6 +635,110 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) {
// current exp offset: 15
// Add 112 (127 - 15) to compensate for the difference in exponent bias
// bf16 = ((nosign >> (8 - 4)) + (112 << 7)) | sign;
// bf16 = ((nosign >> 4) + 0x3800) | sign;
Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
"{"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t"
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0
"and.b32 sign0, a0, 0x80008000; \n\t"
"and.b32 sign1, a1, 0x80008000; \n\t"
"and.b32 nosign0, a0, 0x7fff7fff; \n\t"
"and.b32 nosign1, a1, 0x7fff7fff; \n\t"
"shr.b32 nosign0, nosign0, 4; \n\t"
"shr.b32 nosign1, nosign1, 4; \n\t"
"add.u32 nosign0, nosign0, 0x38003800; \n\t"
"add.u32 nosign1, nosign1, 0x38003800; \n\t"
"or.b32 $0, sign0, nosign0; \n\t"
"or.b32 $1, sign1, nosign1; \n\t"
"}", "=r,=r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
packed_in = insert_elt(packed_in, in1, (uint64_t)1);
packed_in = insert_elt(packed_in, in2, (uint64_t)2);
packed_in = insert_elt(packed_in, in3, (uint64_t)3);
Value *in = bit_cast(packed_in, i32_ty);
Value *ret = call(ptx, {in});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0);
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1);
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0);
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
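A matching pure-Python model of one lane (illustrative, not from the diff), which makes the re-biasing explicit:

    def fp8_to_bf16_bits(b: int) -> int:
        # prmt: place the fp8 byte in the high byte of a 16-bit lane
        h = (b & 0xff) << 8
        sign = h & 0x8000
        nosign = h & 0x7fff
        # Shift the exponent/mantissa into bf16 position and add the
        # bias difference, (127 - 15) << 7 == 0x3800.
        return ((nosign >> 4) + 0x3800) | sign

Note that fp8 0x00 maps to bf16 0x3800 (i.e. 2**-15) rather than to bf16 zero; the reverse routine's clamp floor of 0x3800 is what keeps the round trip exact.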
std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
/* Assuming the fp8 exponent offset is 15 and the bf16 exponent offset is 127.
Max value in fp8: 0b01111111 (0x7f), as bf16: 0x3ff0
Min value in fp8: 0b00000000 (0x00), as bf16: 0x3800
// @note: +0x8 rounds the 4 dropped mantissa bits to nearest
fp8 = ((nosign(bf16) + 0x8 - (112 << 7)) >> 4);
return fp8 | sign; // bytes are then packed with prmt
*/
InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
"{\n\t"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t"
"mov.u32 fp8_min, 0x38003800; \n\t"
"mov.u32 fp8_max, 0x3ff03ff0; \n\t"
"mov.u32 rn_, 0x80008; \n\t"
"mov.u32 zero, 0; \n\t"
"and.b32 sign0, $1, 0x80008000; \n\t"
"and.b32 sign1, $2, 0x80008000; \n\t"
"prmt.b32 sign, sign0, sign1, 0x7531; \n\t"
"and.b32 nosign0, $1, 0x7fff7fff; \n\t"
"and.b32 nosign1, $2, 0x7fff7fff; \n\t"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max)
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t"
"add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero
"add.u32 nosign1, nosign1, rn_; \n\t"
"sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset
"sub.u32 nosign1, nosign1, 0x38003800; \n\t"
"shr.u32 nosign0, nosign0, 4; \n\t"
"shr.u32 nosign1, nosign1, 4; \n\t"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t"
"or.b32 $0, nosign, sign; \n\t"
""
"}", "=r,r,r", false);
Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
packed_in0 = insert_elt(packed_in0, in0, (int)0);
packed_in0 = insert_elt(packed_in0, in1, (int)1);
packed_in1 = insert_elt(packed_in1, in2, (int)0);
packed_in1 = insert_elt(packed_in1, in3, (int)1);
Value *in_arg0 = bit_cast(packed_in0, i32_ty);
Value *in_arg1 = bit_cast(packed_in1, i32_ty);
Value *ret = call(ptx, {in_arg0, in_arg1});
Value *ret0 = extract_elt(ret, (int)0);
Value *ret1 = extract_elt(ret, (int)1);
Value *ret2 = extract_elt(ret, (int)2);
Value *ret3 = extract_elt(ret, (int)3);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
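And the reverse direction as a pure-Python model of one lane (illustrative, not from the diff), mirroring the clamp / round / re-bias / shift sequence of the PTX:

    def bf16_to_fp8_bits(h: int) -> int:
        sign = (h & 0x8000) >> 8               # sign moves into the fp8 byte
        nosign = h & 0x7fff
        # min.u32/max.u32: clamp to the bf16 images of fp8 0x00 and 0x7f
        nosign = min(max(nosign, 0x3800), 0x3ff0)
        # +0x8 rounds the 4 mantissa bits about to be dropped, -0x3800
        # removes the bias difference, >>4 shifts into fp8 position
        return ((nosign + 0x8 - 0x3800) >> 4) | sign

With these two models, every fp8 byte survives the round trip, which is exactly the property the extended test below checks on the GPU.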
Value* generator::bf16_to_fp32(Value *in0){
if (tgt_->as_nvidia()->sm() >= 80) {
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
@@ -685,6 +789,11 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
return fp8x4_to_fp16x4(a, b, c, d);
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
return fp8x4_to_fp32x4(a, b, c, d);
// fp8 <> bf16
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty())
return fp8x4_to_bf16x4(a, b, c, d);
if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty())
return bf16x4_to_fp8x4(a, b, c, d);
throw std::runtime_error("unsupported conversion");
};
for(size_t i = 0; i < x_idxs.size(); i+=4){


@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
else{
if(layouts_->has_tmp(v))
return async_write.size() - 1;
// // Ignore copy_to_shared. It won't modify async behavior.
// if(dynamic_cast<ir::copy_to_shared_inst*>(v))
// return 0;
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}


@@ -719,8 +719,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16/bf16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -732,13 +735,13 @@ def test_f8_f16_roundtrip():
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
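The same round-trip property can be sanity-checked on the CPU with the bit-level models sketched earlier (helper names are illustrative, not from the commit):

    def fp8_to_bf16_bits(b):
        h = (b & 0xff) << 8
        return (((h & 0x7fff) >> 4) + 0x3800) | (h & 0x8000)

    def bf16_to_fp8_bits(h):
        nosign = min(max(h & 0x7fff, 0x3800), 0x3ff0)
        return ((nosign + 0x8 - 0x3800) >> 4) | ((h & 0x8000) >> 8)

    # mirrors torch.all(f8_tensor == f8_output_tensor) over range(-128, 128)
    assert all(bf16_to_fp8_bits(fp8_to_bf16_bits(b)) == b for b in range(256))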
@@ -746,7 +749,6 @@ def test_f8_f16_roundtrip():
def test_f16_to_f8_rounding():
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
@triton.jit


@@ -581,6 +581,13 @@ def cast(input: tl.tensor,
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
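With this dispatch in place, user code reaches the new path simply by casting. A minimal sketch (not part of the diff; the kernel and variable names are hypothetical), following the triton.reinterpret pattern used in the test above:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def cast_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        # hypothetical example kernel, not from the commit
        offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(src_ptr + offs, mask=mask)
        # fp8 -> bf16: routed through create_fp_ext and, in the
        # generator, fp8x4_to_bf16x4
        tl.store(dst_ptr + offs, x.to(tl.bfloat16), mask=mask)

    raw = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    f8 = triton.reinterpret(raw, tl.float8)
    bf16 = torch.empty(256, dtype=torch.bfloat16, device='cuda')
    cast_kernel[(1,)](f8, bf16, 256, BLOCK=256)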