[LANG] Preliminary FP8 support (#96)
This commit is contained in:
committed by
Philippe Tillet
parent
4290be1ae8
commit
7355efa745
@@ -27,6 +27,7 @@ using namespace llvm;
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||
@@ -60,6 +61,7 @@ using namespace llvm;
|
||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||
@@ -69,6 +71,7 @@ using namespace llvm;
|
||||
#define select(...) builder_->CreateSelect(__VA_ARGS__)
|
||||
#define store(...) builder_->CreateStore(__VA_ARGS__)
|
||||
#define sub(...) builder_->CreateSub(__VA_ARGS__)
|
||||
#define shl(...) builder_->CreateShl(__VA_ARGS__)
|
||||
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
|
||||
#define urem(...) builder_->CreateURem(__VA_ARGS__)
|
||||
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
|
||||
@@ -101,6 +104,7 @@ Type *generator::cvt(ir::type *ty) {
|
||||
// primitive types
|
||||
switch(ty->get_type_id()){
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
||||
@@ -316,10 +320,108 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false),
|
||||
"{ \n\t"
|
||||
".reg .b32 b<4>; \n\t"
|
||||
"shl.b32 b0, $1, 4; \n\t" // shift into into upper byte
|
||||
"shl.b32 b1, $2, 4; \n\t"
|
||||
"shl.b32 b2, $3, 4; \n\t"
|
||||
"shl.b32 b3, $4, 4; \n\t"
|
||||
"lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign
|
||||
"lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t"
|
||||
"lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t"
|
||||
"lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t"
|
||||
"prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here)
|
||||
"prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here)
|
||||
"prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3
|
||||
"}", "=r, r, r, r, r", false);
|
||||
Value *packed_ret = call(ptx, {in0, in1, in2, in3});
|
||||
Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4));
|
||||
return std::make_tuple(extract_elt(ret, (int)0),
|
||||
extract_elt(ret, (int)1),
|
||||
extract_elt(ret, (int)2),
|
||||
extract_elt(ret, (int)3));
|
||||
}
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
Value *ret0, *ret1, *ret2, *ret3;
|
||||
std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
|
||||
ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
|
||||
ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
|
||||
ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
|
||||
ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
|
||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||
"{"
|
||||
".reg .b32 a<2>, b<2>; \n\t"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
||||
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
||||
"shr.b32 b1, b1, 1; \n\t"
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||
"}", "=r,=r,r", false);
|
||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||
packed_in = insert_elt(packed_in, in0, (int)0);
|
||||
packed_in = insert_elt(packed_in, in1, (int)1);
|
||||
packed_in = insert_elt(packed_in, in2, (int)2);
|
||||
packed_in = insert_elt(packed_in, in3, (int)3);
|
||||
Value *in = bit_cast(packed_in, i32_ty);
|
||||
Value *ret = call(ptx, {in});
|
||||
Value *packed_ret0 = extract_val(ret, {0});
|
||||
Value *packed_ret1 = extract_val(ret, {1});
|
||||
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
||||
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
||||
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
||||
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `cast`
|
||||
*/
|
||||
void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||
// <> FP8
|
||||
ir::value *op = x->get_operand(0);
|
||||
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
|
||||
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
|
||||
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
|
||||
// ensure that conversions can be vectorized
|
||||
int ld = layouts_->get(x)->get_order(0);
|
||||
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
|
||||
if(contiguous % 4 != 0)
|
||||
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
|
||||
auto x_idxs = idxs_.at(x);
|
||||
auto op_idxs = idxs_.at(op);
|
||||
// run the conversion
|
||||
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
|
||||
return fp8x4_to_fp16x4(a, b, c, d);
|
||||
throw std::runtime_error("unsupported conversion");
|
||||
};
|
||||
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
||||
std::tie(vals_[x][x_idxs[i+0]],
|
||||
vals_[x][x_idxs[i+1]],
|
||||
vals_[x][x_idxs[i+2]],
|
||||
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
|
||||
vals_[op][op_idxs[i+1]],
|
||||
vals_[op][op_idxs[i+2]],
|
||||
vals_[op][op_idxs[i+3]]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
auto cvt = [](ir::cast_op_t op){
|
||||
using ll = llvm::Instruction::CastOps;
|
||||
|
@@ -12,9 +12,12 @@ namespace ir{
|
||||
context_impl::context_impl(context &ctx)
|
||||
: void_ty(ctx, type::VoidTyID),
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
// floating point
|
||||
fp8_ty(ctx, type::FP8TyID),
|
||||
half_ty(ctx, type::HalfTyID),
|
||||
float_ty(ctx, type::FloatTyID),
|
||||
double_ty(ctx, type::DoubleTyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
|
@@ -21,6 +21,7 @@ type *type::get_scalar_ty() const {
|
||||
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case HalfTyID: return 16;
|
||||
case FloatTyID: return 32;
|
||||
case DoubleTyID: return 64;
|
||||
@@ -42,6 +43,7 @@ unsigned type::get_tile_bitwidth() const
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == HalfTyID) return 10;
|
||||
if (id == FloatTyID) return 23;
|
||||
if (id == DoubleTyID) return 53;
|
||||
@@ -103,7 +105,7 @@ bool type::is_integer_ty(unsigned width) const
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
@@ -120,7 +122,8 @@ bool type::is_sized() const {
|
||||
// primitive types
|
||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// half
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||
|
Reference in New Issue
Block a user