[LANG] Preliminary FP8 support (#96)

This commit is contained in:
Philippe Tillet
2021-05-01 14:34:33 -04:00
committed by Philippe Tillet
parent 4290be1ae8
commit 7355efa745
10 changed files with 182 additions and 40 deletions

View File

@@ -102,6 +102,10 @@ public:
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);

View File

@@ -22,9 +22,11 @@ public:
context_impl(context &ctx);
public:
// primitive types
type void_ty, label_ty, half_ty, float_ty, double_ty;
// derived types
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, half_ty, float_ty, double_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;

View File

@@ -29,21 +29,22 @@ public:
enum id_t {
// primitive types
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
FloatTyID, ///< 2: 32-bit floating point type
DoubleTyID, ///< 3: 64-bit floating point type
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 7: Labels
MetadataTyID, ///< 8: Metadata
TokenTyID, ///< 9: Token
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
HalfTyID, ///< 3: 16-bit floating point type
FloatTyID, ///< 4: 32-bit floating point type
DoubleTyID, ///< 5: 64-bit floating point type
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 9: Labels
MetadataTyID, ///< 10: Metadata
TokenTyID, ///< 11: Token
// derived types
IntegerTyID, ///< 10: Arbitrary bit width integers
FunctionTyID, ///< 11: Functions
PointerTyID, ///< 12: Pointers
StructTyID, ///< 13: Struct
BlockTyID, ///< 14: Tile
IntegerTyID, ///< 12: Arbitrary bit width integers
FunctionTyID, ///< 13: Functions
PointerTyID, ///< 14: Pointers
StructTyID, ///< 15: Struct
BlockTyID, ///< 16: Block
};
public:
@@ -72,6 +73,7 @@ public:
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
@@ -96,6 +98,7 @@ public:
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_fp8_ty(context &ctx);
static type *get_half_ty(context &ctx);
static type *get_float_ty(context &ctx);
static type *get_double_ty(context &ctx);
@@ -124,6 +127,7 @@ public:
std::string repr() const {
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case HalfTyID: return "f16";
case FloatTyID: return "f32";
case DoubleTyID: return "f64";