[LANG] Add fp16 to fp8 conversion (#444)
This commit is contained in:
@@ -414,13 +414,13 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
|
||||
|
||||
|
||||
/**
 * \brief Convert four fp32 values to four fp8 values.
 *
 * Each input is first narrowed to fp16 with an FPTrunc cast, and the four
 * fp16 values are then handed to fp16x4_to_fp8x4, which performs the actual
 * fp16 -> fp8 packing.
 *
 * The earlier implementation only accepted constant zero inputs and threw
 * "unsupported cast" for everything else; its unconditional return also made
 * the FPTrunc path unreachable. That constant-only path is removed here so
 * that arbitrary (non-constant) inputs are converted.
 *
 * NOTE(review): the value is rounded twice (fp32 -> fp16, then fp16 -> fp8);
 * confirm this is acceptable compared to a single direct rounding.
 *
 * \param in0,in1,in2,in3 the four values to convert (expected to be fp32,
 *        since they are fed to FPTrunc with an fp16 destination type)
 * \return the four converted values, in the same order as the inputs
 */
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
  // Narrow every lane to fp16 so the fp16 -> fp8 routine can be reused.
  in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty);
  in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty);
  in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty);
  in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty);
  // fp16x4_to_fp8x4 already returns the result tuple in lane order.
  return fp16x4_to_fp8x4(in0, in1, in2, in3);
}
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
@@ -439,14 +439,14 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||
"{"
|
||||
".reg .b32 a<2>, b<2>; \n\t"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
||||
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
||||
"shr.b32 b1, b1, 1; \n\t"
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0
|
||||
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
|
||||
"shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position)
|
||||
"shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
|
||||
"}", "=r,=r,r", false);
|
||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
||||
@@ -464,6 +464,51 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
|
||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||
}
|
||||
|
||||
/**
 * \brief Convert four fp16 values to four fp8 values with one inline-PTX call.
 *
 * Bit layouts (s = sign, e = exponent, m = mantissa):
 *   fp16: s eeeee mmmmmmmmmm
 *   fp8 : s eeee  mmm
 *
 * The fp8 exponent is the low 4 exponent bits of the fp16 value and the fp8
 * mantissa is the high 3 mantissa bits of the fp16 value. Those 7 bits are
 * contiguous in fp16, so shifting the fp16 left by one places them in bits
 * 14..8, directly below the sign position of the upper byte.
 *
 * Rounding: after the shift, the 4th fp16 mantissa bit (the first one that is
 * dropped) sits at bit 7, so adding 0x0080 rounds to the nearest fp8 value.
 * Per 16-bit lane the code computes
 *   hi_byte_of( (fp16 & 0x8000) | (((fp16 << 1) & 0x7fff) + 0x0080) )
 *
 * Two fp16 lanes are processed per 32-bit register. The mask with 0x7fff7fff
 * clears the (shifted) most significant exponent bit of each lane *before* the
 * add, so a rounding carry cannot ripple from the low lane into the high lane.
 * Inputs that have the most significant exponent bit set are not representable
 * in this fp8 format and produce garbage output.
 *
 * \param in0,in1,in2,in3 the four fp16 values to convert
 * \return the four fp8 (i8) values, in the same order as the inputs
 */
std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
  // Signature of the asm blob: (i32, i32) -> <4 x i8>.
  auto *asm_fn_ty = FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false);
  InlineAsm *cvt_asm = InlineAsm::get(asm_fn_ty,
  "{"
  ".reg .b32 a<2>, b<2>; \n\t"
  "shl.b32 a0, $1, 1; \n\t"                     // a0 = input0 << 1
  "shl.b32 a1, $2, 1; \n\t"                     // a1 = input1 << 1
  "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t"  // a0 &= 0x7fff7fff (drop shifted top exponent bit)
  "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t"  // a1 &= 0x7fff7fff
  "add.u32 a0, a0, 0x00800080; \n\t"            // a0 += 0x00800080 (round to nearest)
  "add.u32 a1, a1, 0x00800080; \n\t"            // a1 += 0x00800080
  "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 (re-attach sign)
  "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1
  "prmt.b32 $0, b0, b1, 0x7531; \n\t"           // keep the high byte of each lane:
                                                // b0 = 0xabcd, b1 = 0x0123 (one letter/digit per byte) -> 0x02ac
  "}", "=r,r,r", false);

  // Pack (in0, in1) and (in2, in3) into two <2 x half> vectors ...
  Value *lo_pair = UndefValue::get(vec_ty(f16_ty, 2));
  Value *hi_pair = UndefValue::get(vec_ty(f16_ty, 2));
  lo_pair = insert_elt(lo_pair, in0, (int)0);
  lo_pair = insert_elt(lo_pair, in1, (int)1);
  hi_pair = insert_elt(hi_pair, in2, (int)0);
  hi_pair = insert_elt(hi_pair, in3, (int)1);

  // ... and reinterpret each pair as a single 32-bit integer operand.
  Value *lo_bits = bit_cast(lo_pair, i32_ty);
  Value *hi_bits = bit_cast(hi_pair, i32_ty);

  // One call converts all four lanes; the result is a <4 x i8> vector.
  Value *packed_fp8 = call(cvt_asm, {lo_bits, hi_bits});

  // Unpack the four bytes in lane order.
  Value *out0 = extract_elt(packed_fp8, (int)0);
  Value *out1 = extract_elt(packed_fp8, (int)1);
  Value *out2 = extract_elt(packed_fp8, (int)2);
  Value *out3 = extract_elt(packed_fp8, (int)3);
  return std::make_tuple(out0, out1, out2, out3);
}
|
||||
|
||||
Value* generator::bf16_to_fp32(Value *in0){
|
||||
if (tgt_->as_nvidia()->sm() >= 80) {
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
|
||||
@@ -508,8 +553,12 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
||||
if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
|
||||
return fp32x4_to_fp8x4(a, b, c, d);
|
||||
if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty())
|
||||
return fp16x4_to_fp8x4(a, b, c, d);
|
||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
|
||||
return fp8x4_to_fp16x4(a, b, c, d);
|
||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
|
||||
return fp8x4_to_fp32x4(a, b, c, d);
|
||||
throw std::runtime_error("unsupported conversion");
|
||||
};
|
||||
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
||||
|
Reference in New Issue
Block a user