[IR] Preliminary support for BF16 (#129)

This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
Philippe Tillet
2021-06-25 10:19:29 -04:00
committed by Philippe Tillet
parent 9b4e2cae2d
commit 8cea583109
18 changed files with 173 additions and 132 deletions

View File

@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
unsigned type::get_primitive_size_in_bits() const {
switch (id_) {
case FP8TyID: return 8;
case HalfTyID: return 16;
case FloatTyID: return 32;
case DoubleTyID: return 64;
case X86_FP80TyID: return 80;
case FP128TyID: return 128;
case PPC_FP128TyID: return 128;
case FP16TyID: return 16;
case BF16TyID: return 16;
case FP32TyID: return 32;
case FP64TyID: return 64;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0;
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == FP8TyID) return 3;
if (id == HalfTyID) return 10;
if (id == FloatTyID) return 23;
if (id == DoubleTyID) return 53;
if (id == FP16TyID) return 10;
if (id == BF16TyID) return 7;
if (id == FP32TyID) return 23;
if (id == FP64TyID) return 53;
throw std::runtime_error("unreachable");
}
@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const
bool type::is_floating_point_ty() const
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
bool type::is_sized() const {
// primitive types are sized
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
// floating point
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
// integer types
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }