From 137bb67fad31568463b1be85ee12706a35650855 Mon Sep 17 00:00:00 2001
From: TC <93944281+tomconerlyanth@users.noreply.github.com>
Date: Wed, 2 Feb 2022 20:42:09 -0800
Subject: [PATCH] [LANG] Add fp16 to fp8 conversion (#444)

---
 lib/codegen/selection/generator.cc     | 79 +++++++++++++++++++-----
 python/test/unit/language/test_core.py | 84 ++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 15 deletions(-)

diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 26b6b342a..d2ebce1c6 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -414,13 +414,13 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
 std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
-  auto cvt = [this](Value *v){
-    if(ConstantFP* ci = dyn_cast<ConstantFP>(v))
-      if(ci->getValue().convertToFloat() == 0)
-        return builder_->getInt8(0);
-    throw std::runtime_error("unsupported cast");
-  };
-  return std::make_tuple(cvt(in0), cvt(in1), cvt(in2), cvt(in3));
+  in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty);
+  in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty);
+  in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty);
+  in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty);
+  Value *ret0, *ret1, *ret2, *ret3;
+  std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3);
+  return std::make_tuple(ret0, ret1, ret2, ret3);
 }
 
 std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
@@ -439,14 +439,14 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
   InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
   "{"
   ".reg .b32 a<2>, b<2>; \n\t"
-  "prmt.b32 a0, 0, $2, 0x5140; \n\t"
-  "prmt.b32 a1, 0, $2, 0x7362; \n\t"
-  "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
-  "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
-  "shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
-  "shr.b32 b1, b1, 1; \n\t"
-  "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
-  "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
+  "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0
+  "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
+  "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
+  "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
+  "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position)
+  "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
+  "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
+  "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
   "}", "=r,=r,r", false);
   Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
   packed_in = insert_elt(packed_in, in0, (uint64_t)0);
@@ -464,6 +464,51 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
   return std::make_tuple(ret0, ret1, ret2, ret3);
 }
 
+std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
+  /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa)
+   * fp8 bit representation is seeeemmm
+   * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16.
+   * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16.
+   * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous.
+   * We want to round to the nearest fp8 value. To do that, add 1 to the 4th mantissa bit in fp16 (that's
+   * one more than the number of mantissa bits in fp8).
+   * fp8 = (fp16 & 0x8000) | (((fp16 << 1) + 0x0080) & 0x7fff)
+   *
+   * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the
+   * other. To avoid this we zero out the most significant exponent bit. If that bit is set then
+   * the value isn't representable in float8 anyway so we assume it's never set (and give garbage
+   * output if it is). If we were willing to rely on that bit never being set we could save the
+   * first two lop3.b32 instructions below.
+   */
+  InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
+  "{"
+  ".reg .b32 a<2>, b<2>; \n\t"
+  "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1
+  "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1
+  "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = a0 & 0x7fff7fff (zero the msb exponent bit of each half)
+  "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = a1 & 0x7fff7fff (zero the msb exponent bit of each half)
+  "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080 (round to nearest)
+  "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080 (round to nearest)
+  "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 (restore sign)
+  "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1 (restore sign)
+  "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1 = 0x0123 set output to 0x02ac
+  "}", "=r,r,r", false);
+  Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2));
+  Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2));
+  packed_in0 = insert_elt(packed_in0, in0, (int)0);
+  packed_in0 = insert_elt(packed_in0, in1, (int)1);
+  packed_in1 = insert_elt(packed_in1, in2, (int)0);
+  packed_in1 = insert_elt(packed_in1, in3, (int)1);
+  Value *in_arg0 = bit_cast(packed_in0, i32_ty);
+  Value *in_arg1 = bit_cast(packed_in1, i32_ty);
+  Value *ret = call(ptx, {in_arg0, in_arg1});
+  Value *ret0 = extract_elt(ret, (int)0);
+  Value *ret1 = extract_elt(ret, (int)1);
+  Value *ret2 = extract_elt(ret, (int)2);
+  Value *ret3 = extract_elt(ret, (int)3);
+  return std::make_tuple(ret0, ret1, ret2, ret3);
+}
+
 Value* generator::bf16_to_fp32(Value *in0){
   if (tgt_->as_nvidia()->sm() >= 80) {
     InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
@@ -508,8 +553,12 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
   auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
     if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
       return fp32x4_to_fp8x4(a, b, c, d);
+    if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty())
+      return fp16x4_to_fp8x4(a, b, c, d);
     if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
       return fp8x4_to_fp16x4(a, b, c, d);
+    if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
+      return fp8x4_to_fp32x4(a, b, c, d);
     throw std::runtime_error("unsupported conversion");
   };
   for(size_t i = 0; i < x_idxs.size(); i+=4){
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 3e35700f8..a49b47585 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -565,6 +565,90 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     z_ref = x.astype(getattr(np, dtype_z))
     assert to_numpy(z_tri) == z_ref
 
+
+def test_f8_f16_roundtrip():
+    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
+    f8 = triton.reinterpret(f8_tensor, tl.float8)
+    n_elements = f8_tensor.numel()
+    f16 = torch.empty_like(f8_tensor, dtype=torch.float16)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024)
+
+    f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
+    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
+    print(f16.dtype, f8_output.dtype)
+    copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
+
+    assert torch.all(f8_tensor == f8_output_tensor)
+
+
+def test_f16_to_f8_rounding():
+    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
+    error is the minimum over all float8.
+
+    Or the same explanation a bit more mathematically:
+    for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
+    f16_input_np = (
+        np.array(
+            range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
+        )
+        .view(np.float16)
+    )
+    f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
+    n_elements = f16_input.numel()
+    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
+    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
+
+    f16_output = torch.empty_like(f16_input, dtype=torch.float16)
+    copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
+
+    abs_error = torch.abs(f16_input - f16_output)
+
+    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
+    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
+    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
+    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
+
+    all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
+        torch.isfinite(all_f8_vals_in_f16)
+    ]
+
+    min_error = torch.min(
+        torch.abs(
+            f16_input.reshape((-1, 1))
+            - all_finite_f8_vals_in_f16.reshape((1, -1))
+        ),
+        dim=1,
+    )[0]
+    # float8 max is 1.875; f16 values below 1.9375 still round to a finite float8
+    mismatch = torch.logical_and(
+        abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
+    )
+    assert torch.all(
+        torch.logical_not(mismatch)
+    ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}"
+
+
 # ---------------
 # test reduce
 # ---------------
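
Editor's note: the PTX in the patch is easiest to sanity-check against a host-side model. Below is a minimal Python sketch (not part of the patch; the helper name f8_to_f16_bits is made up for illustration) of the fp8 to fp16 direction in fp8x4_to_fp16x4, applied to one byte at a time instead of four packed values.

def f8_to_f16_bits(b: int) -> int:
    """Scalar model of the prmt / strip-sign / shift / restore-sign sequence in fp8x4_to_fp16x4."""
    h = (b & 0xff) << 8                # prmt: place the fp8 byte in the high byte of a 16-bit lane
    sign = h & 0x8000                  # remember the sign bit
    return sign | ((h & 0x7fff) >> 1)  # strip sign, shift the payload into fp16 position, restore sign

# Example: fp8 0x70 (sign 0, exponent 1110, mantissa 000) maps to fp16 bits 0x3800, which is 0.5.
assert f8_to_f16_bits(0x70) == 0x3800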
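The fp16 to fp8 direction can be modeled the same way. This sketch follows the rounding formula quoted in the fp16x4_to_fp8x4 comment, one value at a time rather than two fp16s per 32-bit register; f16_bits_to_f8 is likewise a hypothetical helper and is only meaningful for inputs whose most significant exponent bit is clear.

def f16_bits_to_f8(bits: int) -> int:
    """Scalar model of fp8 = (fp16 & 0x8000) | (((fp16 << 1) + 0x0080) & 0x7fff), keeping the high byte."""
    sign = bits & 0x8000                    # fp16 sign bit
    body = ((bits << 1) + 0x0080) & 0x7fff  # align the 4 exponent + 3 mantissa bits and round to nearest
    return (sign | body) >> 8               # the high byte is the fp8 encoding

# Examples: 0.5 is fp16 0x3800 and encodes to fp8 0x70; 1.875 (fp16 0x3f80) is the largest finite fp8 value, 0x7f.
assert f16_bits_to_f8(0x3800) == 0x70
assert f16_bits_to_f8(0x3f80) == 0x7f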
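Composing the two sketches gives a quick host-side analogue of the property that test_f8_f16_roundtrip checks on the GPU: every fp8 bit pattern survives an f8 to f16 to f8 round trip unchanged.

# Bit-level round trip, mirroring test_f8_f16_roundtrip (uses the two helpers sketched above).
for b in range(256):
    assert f16_bits_to_f8(f8_to_f16_bits(b)) == b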
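The same helpers can brute-force the rounding property behind test_f16_to_f8_rounding without a GPU: for every finite fp16 whose magnitude is below the 1.9375 cutoff, the model's round trip should be as close to the input as any fp8 value can get. This is only a sketch of the invariant under the stated assumptions, not a replacement for the test.

import numpy as np

f16_bits = np.arange(2 ** 16, dtype=np.uint16)
f16_vals = f16_bits.view(np.float16).astype(np.float32)

# fp16 images of every fp8 byte, via the decode sketch above (all 256 happen to be finite in this format).
f8_vals = np.array([f8_to_f16_bits(b) for b in range(2 ** 8)], dtype=np.uint16).view(np.float16).astype(np.float32)
f8_vals = f8_vals[np.isfinite(f8_vals)]

# fp16 -> fp8 -> fp16 round trip of the model for every fp16 bit pattern.
roundtrip = np.array([f8_to_f16_bits(f16_bits_to_f8(int(b))) for b in f16_bits], dtype=np.uint16)
roundtrip_vals = roundtrip.view(np.float16).astype(np.float32)

in_range = np.isfinite(f16_vals) & (np.abs(f16_vals) < 1.9375)
abs_error = np.abs(f16_vals - roundtrip_vals)
min_error = np.abs(f16_vals[:, None] - f8_vals[None, :]).min(axis=1)
assert np.all(abs_error[in_range] == min_error[in_range])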