[IR] Preliminary support for BF16 (#129)

This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
Philippe Tillet
2021-06-25 10:19:29 -04:00
committed by Philippe Tillet
parent 9b4e2cae2d
commit 8cea583109
18 changed files with 173 additions and 132 deletions

View File

@@ -55,10 +55,10 @@ value *builder::get_int64(int64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_float16(float val)
{ return constant_fp::get(type::get_half_ty(ctx_), val); }
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
value *builder::get_float32(float val)
{ return constant_fp::get(type::get_float_ty(ctx_), val); }
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
value *builder::get_range(int32_t _lo, int32_t _hi) {
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
@@ -85,13 +85,13 @@ type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); }
type *builder::get_half_ty()
{ return type::get_half_ty(ctx_); }
{ return type::get_fp16_ty(ctx_); }
type *builder::get_float_ty()
{ return type::get_float_ty(ctx_); }
{ return type::get_fp32_ty(ctx_); }
type *builder::get_double_ty()
{ return type::get_double_ty(ctx_); }
{ return type::get_fp64_ty(ctx_); }
//===----------------------------------------------------------------------===//

View File

@@ -15,12 +15,12 @@ constant *constant::get_null_value(type *ty) {
switch (ty->get_scalar_ty()->get_type_id()) {
case type::IntegerTyID:
return constant_int::get(ty, 0);
case type::HalfTyID:
return constant_fp::get(type::get_half_ty(ctx), 0);
case type::FloatTyID:
return constant_fp::get(type::get_float_ty(ctx), 0);
case type::DoubleTyID:
return constant_fp::get(type::get_double_ty(ctx), 0);
case type::FP16TyID:
return constant_fp::get(type::get_fp16_ty(ctx), 0);
case type::FP32TyID:
return constant_fp::get(type::get_fp32_ty(ctx), 0);
case type::FP64TyID:
return constant_fp::get(type::get_fp64_ty(ctx), 0);
default:
throw std::runtime_error("Cannot create a null constant of that type!");
}

View File

@@ -14,9 +14,10 @@ context_impl::context_impl(context &ctx)
label_ty(ctx, type::LabelTyID),
// floating point
fp8_ty(ctx, type::FP8TyID),
half_ty(ctx, type::HalfTyID),
float_ty(ctx, type::FloatTyID),
double_ty(ctx, type::DoubleTyID),
fp16_ty(ctx, type::FP16TyID),
bf16_ty(ctx, type::BF16TyID),
fp32_ty(ctx, type::FP32TyID),
fp64_ty(ctx, type::FP64TyID),
// integers
int1_ty(ctx, 1),
int8_ty(ctx, 8),

View File

@@ -37,16 +37,16 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly
// converted to double
if(a_ty->is_double_ty() || b_ty->is_double_ty())
return type::get_double_ty(ctx);
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
return type::get_fp64_ty(ctx);
// 2) if one operand is float, the other is implicitly
// converted to float
if(a_ty->is_float_ty() || b_ty->is_float_ty())
return type::get_float_ty(ctx);
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
return type::get_fp32_ty(ctx);
// 3 ) if one operand is half, the other is implicitly
// converted to half
if(a_ty->is_half_ty() || b_ty->is_half_ty())
return type::get_half_ty(ctx);
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
return type::get_fp16_ty(ctx);
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("augment_types");
// 4 ) both operands are integer and undergo

View File

@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
unsigned type::get_primitive_size_in_bits() const {
switch (id_) {
case FP8TyID: return 8;
case HalfTyID: return 16;
case FloatTyID: return 32;
case DoubleTyID: return 64;
case X86_FP80TyID: return 80;
case FP128TyID: return 128;
case PPC_FP128TyID: return 128;
case FP16TyID: return 16;
case BF16TyID: return 16;
case FP32TyID: return 32;
case FP64TyID: return 64;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0;
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == FP8TyID) return 3;
if (id == HalfTyID) return 10;
if (id == FloatTyID) return 23;
if (id == DoubleTyID) return 53;
if (id == FP16TyID) return 10;
if (id == BF16TyID) return 7;
if (id == FP32TyID) return 23;
if (id == FP64TyID) return 53;
throw std::runtime_error("unreachable");
}
@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const
bool type::is_floating_point_ty() const
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
bool type::is_sized() const {
// primitive types are sized
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
// floating point
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
// integer types
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }