[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
committed by
Philippe Tillet
parent
9b4e2cae2d
commit
8cea583109
@@ -55,10 +55,10 @@ value *builder::get_int64(int64_t val)
|
||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_float16(float val)
|
||||
{ return constant_fp::get(type::get_half_ty(ctx_), val); }
|
||||
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_float32(float val)
|
||||
{ return constant_fp::get(type::get_float_ty(ctx_), val); }
|
||||
{ return constant_fp::get(type::get_fp32_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_range(int32_t _lo, int32_t _hi) {
|
||||
constant_int* lo = static_cast<constant_int*>(get_int32(_lo));
|
||||
@@ -85,13 +85,13 @@ type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_half_ty()
|
||||
{ return type::get_half_ty(ctx_); }
|
||||
{ return type::get_fp16_ty(ctx_); }
|
||||
|
||||
type *builder::get_float_ty()
|
||||
{ return type::get_float_ty(ctx_); }
|
||||
{ return type::get_fp32_ty(ctx_); }
|
||||
|
||||
type *builder::get_double_ty()
|
||||
{ return type::get_double_ty(ctx_); }
|
||||
{ return type::get_fp64_ty(ctx_); }
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -15,12 +15,12 @@ constant *constant::get_null_value(type *ty) {
|
||||
switch (ty->get_scalar_ty()->get_type_id()) {
|
||||
case type::IntegerTyID:
|
||||
return constant_int::get(ty, 0);
|
||||
case type::HalfTyID:
|
||||
return constant_fp::get(type::get_half_ty(ctx), 0);
|
||||
case type::FloatTyID:
|
||||
return constant_fp::get(type::get_float_ty(ctx), 0);
|
||||
case type::DoubleTyID:
|
||||
return constant_fp::get(type::get_double_ty(ctx), 0);
|
||||
case type::FP16TyID:
|
||||
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||
case type::FP32TyID:
|
||||
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||
case type::FP64TyID:
|
||||
return constant_fp::get(type::get_fp64_ty(ctx), 0);
|
||||
default:
|
||||
throw std::runtime_error("Cannot create a null constant of that type!");
|
||||
}
|
||||
|
@@ -14,9 +14,10 @@ context_impl::context_impl(context &ctx)
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
// floating point
|
||||
fp8_ty(ctx, type::FP8TyID),
|
||||
half_ty(ctx, type::HalfTyID),
|
||||
float_ty(ctx, type::FloatTyID),
|
||||
double_ty(ctx, type::DoubleTyID),
|
||||
fp16_ty(ctx, type::FP16TyID),
|
||||
bf16_ty(ctx, type::BF16TyID),
|
||||
fp32_ty(ctx, type::FP32TyID),
|
||||
fp64_ty(ctx, type::FP64TyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
|
@@ -37,16 +37,16 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
|
||||
context &ctx = a_ty->get_context();
|
||||
// 1) if one operand is double, the other is implicitly
|
||||
// converted to double
|
||||
if(a_ty->is_double_ty() || b_ty->is_double_ty())
|
||||
return type::get_double_ty(ctx);
|
||||
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
|
||||
return type::get_fp64_ty(ctx);
|
||||
// 2) if one operand is float, the other is implicitly
|
||||
// converted to float
|
||||
if(a_ty->is_float_ty() || b_ty->is_float_ty())
|
||||
return type::get_float_ty(ctx);
|
||||
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
|
||||
return type::get_fp32_ty(ctx);
|
||||
// 3 ) if one operand is half, the other is implicitly
|
||||
// converted to half
|
||||
if(a_ty->is_half_ty() || b_ty->is_half_ty())
|
||||
return type::get_half_ty(ctx);
|
||||
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
|
||||
return type::get_fp16_ty(ctx);
|
||||
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||
throw_unreachable("augment_types");
|
||||
// 4 ) both operands are integer and undergo
|
||||
|
@@ -22,12 +22,10 @@ type *type::get_scalar_ty() const {
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case HalfTyID: return 16;
|
||||
case FloatTyID: return 32;
|
||||
case DoubleTyID: return 64;
|
||||
case X86_FP80TyID: return 80;
|
||||
case FP128TyID: return 128;
|
||||
case PPC_FP128TyID: return 128;
|
||||
case FP16TyID: return 16;
|
||||
case BF16TyID: return 16;
|
||||
case FP32TyID: return 32;
|
||||
case FP64TyID: return 64;
|
||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||
default: return 0;
|
||||
@@ -44,9 +42,10 @@ unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == HalfTyID) return 10;
|
||||
if (id == FloatTyID) return 23;
|
||||
if (id == DoubleTyID) return 53;
|
||||
if (id == FP16TyID) return 10;
|
||||
if (id == BF16TyID) return 7;
|
||||
if (id == FP32TyID) return 23;
|
||||
if (id == FP64TyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
@@ -105,7 +104,7 @@ bool type::is_integer_ty(unsigned width) const
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
@@ -124,9 +123,10 @@ type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||
type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
|
||||
type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
|
||||
type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
|
||||
type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
|
||||
// integer types
|
||||
integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
|
||||
integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
|
||||
|
Reference in New Issue
Block a user