[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
committed by
Philippe Tillet
parent
9b4e2cae2d
commit
8cea583109
@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case HalfTyID: return 16;
|
||||
case FloatTyID: return 32;
|
||||
case DoubleTyID: return 64;
|
||||
case X86_FP80TyID: return 80;
|
||||
case FP128TyID: return 128;
|
||||
case PPC_FP128TyID: return 128;
|
||||
case FP16TyID: return 16;
|
||||
case BF16TyID: return 16;
|
||||
case FP32TyID: return 32;
|
||||
case FP64TyID: return 64;
|
||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||
default: return 0;
|
||||
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == HalfTyID) return 10;
|
||||
if (id == FloatTyID) return 23;
|
||||
if (id == DoubleTyID) return 53;
|
||||
if (id == FP16TyID) return 10;
|
||||
if (id == BF16TyID) return 7;
|
||||
if (id == FP32TyID) return 23;
|
||||
if (id == FP64TyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
|
||||
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
|
||||
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
|
||||
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
|
||||
// integer types
|
||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||
|
Reference in New Issue
Block a user