[BACKEND] Better bf16 support (#588)

This commit is contained in:
daadaada
2022-07-20 12:22:37 +08:00
committed by GitHub
parent 86cab58d89
commit 9b2bc88d11
6 changed files with 180 additions and 62 deletions

View File

@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
@@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
*/
// Emits one result value per index of binary operator `x` into vals_[x].
// bf16 operands are lowered through hand-written inline asm (fma.rn.bf16);
// every other type goes through the regular builder helpers.
void generator::visit_binary_operator(ir::binary_operator*x) {
using ll = llvm::Instruction::BinaryOps;
using tt = ir::binary_op_t;
auto cvt = [](ir::binary_op_t op){
using tt = ir::binary_op_t;
switch(op) {
case tt::Add: return ll::Add;
case tt::FAdd: return ll::FAdd;
@@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
// manually select bf16 bin op
// Each supported op is expressed as a single fused multiply-add on 16-bit
// ("h") operands, with the neutral constant moved into a scratch .b16 register.
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
// Both operands must be bf16; mixed-type operations are not handled here.
// The check is only active when assertions are compiled in.
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
InlineAsm *bf16_add_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
"{ .reg .b16 c; \n\t"
" mov.b16 c, 0x3f80U; \n\t" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
InlineAsm *bf16_sub_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0xbf80U; \n\t" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
InlineAsm *bf16_mul_asm =
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
" { .reg .b16 c; \n\t"
" mov.b16 c, 0x8000U; \n\t" // -0.0 (sign bit only), not +0.0: adding -0.0 keeps the product's value and zero sign -- NOTE(review): presumably intentional, confirm
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
"=h,h,h", false);
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
} else
// Only FAdd, FSub and FMul have a bf16 lowering; any other operator on
// bf16 operands is rejected here.
throw std::runtime_error("invalid bin op for bf16");
} else { // not bf16
auto op = cvt(x->get_op());
if(op == ll::Add)
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
// fp32 division for which IEEE rounding was not requested is emitted as
// div.full.f32 inline asm instead of an LLVM FDiv.
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}
}
}
@@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){
has_l2_evict_policy = false;
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
if(ty->isIntegerTy(1))
ty = builder_->getInt8Ty();
for(size_t i = 0; i < idxs.size(); i += vec){
@@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// pointer to temporary shared memory
Type *ty = cvt(out->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){
// Materializes floating-point constant `x` for every index in idxs_.at(x).
// bf16 constants are produced by a `mov.b16` inline-asm call that loads the raw
// 16-bit pattern; all other scalar types use ConstantFP::get on the converted type.
void generator::visit_constant_fp(ir::constant_fp *x){
Type *ty = cvt(x->get_type()->get_scalar_ty());
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = ConstantFP::get(ty, x->get_value());
for(indices_t idx: idxs_.at(x)) {
// manually select bf16 constant
// NOTE(review): nothing computed in this branch depends on idx, so the bit
// pattern and asm string could be built once before the loop.
if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
// highest 16 bits of fp32
// The low 16 bits of the fp32 pattern are discarded, so the value is
// truncated rather than rounded to nearest.
float fp32_value = x->get_value();
// NOTE(review): reading a float through a uint32_t* breaks strict aliasing;
// std::memcpy into a uint32_t would be the well-defined equivalent.
uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
& 0xffff0000) >> 16;
// Spell the pattern as an unsigned hex literal for the asm string,
// e.g. 1.0f (0x3f800000) becomes "0x3f80U".
std::stringstream const_str;
const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
" mov.b16 $0, " + const_str.str() + ";",
"=h", false);
vals_[x][idx] = builder_->CreateCall(bf16_const, {});
} else
vals_[x][idx] = ConstantFP::get(ty, x->get_value());
}
}
void generator::visit_alloc_const(ir::alloc_const *alloc) {

View File

@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
return constant_int::get(ty, 0);
case type::FP16TyID:
return constant_fp::get(type::get_fp16_ty(ctx), 0);
case type::BF16TyID:
return constant_fp::get(type::get_bf16_ty(ctx), 0);
case type::FP32TyID:
return constant_fp::get(type::get_fp32_ty(ctx), 0);
case type::FP64TyID: