|
|
|
@@ -569,10 +569,10 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
|
|
|
|
|
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0
|
|
|
|
|
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign)
|
|
|
|
|
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign)
|
|
|
|
|
"shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion)
|
|
|
|
|
"shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position)
|
|
|
|
|
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign)
|
|
|
|
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign)
|
|
|
|
|
"shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 poistion)
|
|
|
|
|
"shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position)
|
|
|
|
|
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign)
|
|
|
|
|
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign)
|
|
|
|
|
"}", "=r,=r,r", false);
|
|
|
|
|
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
|
|
|
|
packed_in = insert_elt(packed_in, in0, (uint64_t)0);
|
|
|
|
@@ -635,6 +635,110 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0
|
|
|
|
|
return std::make_tuple(ret0, ret1, ret2, ret3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) {
// Widen four packed fp8 bytes into four bf16 values with one PTX asm block.
// Exponent-bias compensation: fp8 offset here is 15, bf16 offset is 127, so
// 112 (= 127 - 15) is added to the exponent field after the bits are moved
// into bf16 position:
//   bf16 = ((nosign >> 4) + (112 << 7)) | sign
//        = ((nosign >> 4) + 0x3800)     | sign   (per 16-bit lane)
// The asm returns a struct of two <2 x bf16> vectors (two i32-sized words).
Type *pair_struct_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)});
InlineAsm *cvt_asm = InlineAsm::get(FunctionType::get(pair_struct_ty, {i32_ty}, false),
"{"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t"
"prmt.b32 a0, 0, $2, 0x5040; \n\t" // spread low bytes:  0xdcba -> 0xb0a0
"prmt.b32 a1, 0, $2, 0x7060; \n\t" // spread high bytes: 0xdcba -> 0xd0c0
"and.b32 sign0, a0, 0x80008000; \n\t" // save per-lane sign bits
"and.b32 sign1, a1, 0x80008000; \n\t"
"and.b32 nosign0, a0, 0x7fff7fff; \n\t" // magnitude only
"and.b32 nosign1, a1, 0x7fff7fff; \n\t"
"shr.b32 nosign0, nosign0, 4; \n\t" // shift exp/mantissa into bf16 position
"shr.b32 nosign1, nosign1, 4; \n\t"
"add.u32 nosign0, nosign0, 0x38003800; \n\t" // add exponent-bias delta (112 << 7) per lane
"add.u32 nosign1, nosign1, 0x38003800; \n\t"
"or.b32 $0, sign0, nosign0; \n\t" // reattach signs
"or.b32 $1, sign1, nosign1; \n\t"
"}", "=r,=r,r", false);
// Pack the four i8 inputs into one i32 word for the asm operand.
Value *byte_vec = UndefValue::get(vec_ty(i8_ty, 4));
Value *srcs[4] = {in0, in1, in2, in3};
for (uint64_t i = 0; i < 4; ++i)
  byte_vec = insert_elt(byte_vec, srcs[i], i);
Value *word = bit_cast(byte_vec, i32_ty);
Value *pairs = call(cvt_asm, {word});
// Unpack the two <2 x bf16> results back into four scalar values.
Value *lo_pair = extract_val(pairs, {0});
Value *hi_pair = extract_val(pairs, {1});
return std::make_tuple(extract_elt(lo_pair, (uint64_t)0),
                       extract_elt(lo_pair, (uint64_t)1),
                       extract_elt(hi_pair, (uint64_t)0),
                       extract_elt(hi_pair, (uint64_t)1));
}
|
|
|
|
|
|
|
|
|
|
std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) {
/* Narrows four bf16 values into four packed fp8 bytes with one PTX asm block.
   The four bf16 inputs are packed as two <2 x bf16> vectors (one i32 word each)
   and handed to the asm as $1/$2; the asm returns one i32 holding four fp8 bytes.

   Assuming fp8 exponent offset is 16. bf16 exponent offset is 127.
   Max value in fp8: 0b01111111 (0x7f),
   bf16: 3ff0
   Min value in fp8: 0b00000000 (0x00)
   bf16: 0x3c00
   NOTE(review): the min stated above (0x3c00) does not match the clamp floor
   used below (0x3800 per lane) — confirm which is intended.

   // @note: +0x8 is for "rounding to nearest zero"
   fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4;
   return fp8 | sign; // also permute bytes
*/
InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false),
"{\n\t"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t"
"mov.u32 fp8_min, 0x38003800; \n\t" // NOTE(review): fp8_min, fp8_max and zero are
"mov.u32 fp8_max, 0x3ff03ff0; \n\t" // set but never read below (the clamp uses
"mov.u32 rn_, 0x80008; \n\t"        // immediates); only rn_ (rounding bias, 0x8
"mov.u32 zero, 0; \n\t"             // per lane) is used — candidates for removal.
"and.b32 sign0, $1, 0x80008000; \n\t" // extract per-lane sign bits of both words
"and.b32 sign1, $2, 0x80008000; \n\t"
"prmt.b32 sign, sign0, sign1, 0x7531; \n\t" // gather the four sign-carrying bytes
"and.b32 nosign0, $1, 0x7fff7fff; \n\t" // magnitude only
"and.b32 nosign1, $2, 0x7fff7fff; \n\t"

".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max)
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t" // clamp upper 16-bit lane of word 0
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t" // clamp lower 16-bit lane of word 0
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t" // recombine lanes of word 0
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t" // same clamp for word 1, upper lane
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t" // word 1, lower lane
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t"

"add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero (adds 0x8 per lane before the >>4)
"add.u32 nosign1, nosign1, rn_; \n\t"
"sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset (112 << 7 per lane)
"sub.u32 nosign1, nosign1, 0x38003800; \n\t"
"shr.u32 nosign0, nosign0, 4; \n\t" // shift into fp8 bit position
"shr.u32 nosign1, nosign1, 4; \n\t"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t" // gather the four magnitude bytes
"or.b32 $0, nosign, sign; \n\t" // reattach signs -> four packed fp8 bytes
""
"}", "=r,r,r", false);
// Pack the four bf16 scalars into two <2 x bf16> vectors, one i32 word each.
Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
packed_in0 = insert_elt(packed_in0, in0, (int)0);
packed_in0 = insert_elt(packed_in0, in1, (int)1);
packed_in1 = insert_elt(packed_in1, in2, (int)0);
packed_in1 = insert_elt(packed_in1, in3, (int)1);
Value *in_arg0 = bit_cast(packed_in0, i32_ty);
Value *in_arg1 = bit_cast(packed_in1, i32_ty);
Value *ret = call(ptx, {in_arg0, in_arg1});
// Unpack the <4 x i8> asm result into four scalar fp8 bytes.
Value *ret0 = extract_elt(ret, (int)0);
Value *ret1 = extract_elt(ret, (int)1);
Value *ret2 = extract_elt(ret, (int)2);
Value *ret3 = extract_elt(ret, (int)3);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
|
|
|
|
|
|
|
|
|
|
Value* generator::bf16_to_fp32(Value *in0){
|
|
|
|
|
if (tgt_->as_nvidia()->sm() >= 80) {
|
|
|
|
|
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
|
|
|
|
@@ -685,6 +789,11 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
|
|
|
|
|
return fp8x4_to_fp16x4(a, b, c, d);
|
|
|
|
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty())
|
|
|
|
|
return fp8x4_to_fp32x4(a, b, c, d);
|
|
|
|
|
// fp8 <> bf16
|
|
|
|
|
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty())
|
|
|
|
|
return fp8x4_to_bf16x4(a, b, c, d);
|
|
|
|
|
if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty())
|
|
|
|
|
return bf16x4_to_fp8x4(a, b, c, d);
|
|
|
|
|
throw std::runtime_error("unsupported conversion");
|
|
|
|
|
};
|
|
|
|
|
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
|
|
|
|