[IR] Preliminary support for BF16 (#129)

This PR adds a BF16 data-type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other kinds of ops on bfloat16 are not yet supported.
This commit is contained in:
Philippe Tillet
2021-06-25 10:19:29 -04:00
committed by Philippe Tillet
parent 9b4e2cae2d
commit 8cea583109
18 changed files with 173 additions and 132 deletions

View File

@@ -134,6 +134,9 @@ public:
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);

View File

@@ -25,7 +25,7 @@ public:
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, half_ty, float_ty, double_ty;
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types

View File

@@ -28,23 +28,21 @@ protected:
public:
enum id_t {
// primitive types
VoidTyID = 0, ///< 0: type with no size
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
HalfTyID, ///< 3: 16-bit floating point type
FloatTyID, ///< 4: 32-bit floating point type
DoubleTyID, ///< 5: 64-bit floating point type
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 9: Labels
MetadataTyID, ///< 10: Metadata
TokenTyID, ///< 11: Token
VoidTyID = 0, ///< type with no size
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
FP32TyID, ///< 32-bit floating point type
FP64TyID, ///< 64-bit floating point type
LabelTyID, ///< Labels
MetadataTyID, ///< Metadata
TokenTyID, ///< Token
// derived types
IntegerTyID, ///< 12: Arbitrary bit width integers
FunctionTyID, ///< 13: Functions
PointerTyID, ///< 14: Pointers
StructTyID, ///< 15: Struct
BlockTyID, ///< 16: Block
IntegerTyID, ///< Arbitrary bit width integers
FunctionTyID, ///< Functions
PointerTyID, ///< Pointers
StructTyID, ///< Struct
BlockTyID, ///< Block
};
public:
@@ -74,9 +72,10 @@ public:
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
bool is_fp16_ty() const { return id_ == FP16TyID; }
bool is_bf16_ty() const { return id_ == BF16TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
@@ -99,9 +98,10 @@ public:
static type *get_label_ty(context &ctx);
// half
static type *get_fp8_ty(context &ctx);
static type *get_half_ty(context &ctx);
static type *get_float_ty(context &ctx);
static type *get_double_ty(context &ctx);
static type *get_fp16_ty(context &ctx);
static type *get_bf16_ty(context &ctx);
static type *get_fp32_ty(context &ctx);
static type *get_fp64_ty(context &ctx);
// integer types
static integer_type *get_int1_ty(context &ctx);
static integer_type *get_int8_ty(context &ctx);
@@ -128,12 +128,9 @@ public:
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case HalfTyID: return "f16";
case FloatTyID: return "f32";
case DoubleTyID: return "f64";
case X86_FP80TyID: return "f80";
case FP128TyID: return "f128";
case PPC_FP128TyID: return "ppcf128";
case FP16TyID: return "f16";
case FP32TyID: return "f32";
case FP64TyID: return "f64";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";