[BACKEND] Better bf16 support (#588)
This commit is contained in:
@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
// types
|
||||
// Shorthands for LLVM types; each expands to a call on the current IR builder.
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
// bf16 is represented by a 16-bit integer type. Only one definition is kept:
// redefining the macro with a different replacement list is ill-formed.
#define bf16_ty builder_->getInt16Ty()
#define f32_ty builder_->getFloatTy()
#define i1_ty builder_->getInt1Ty()
#define i8_ty builder_->getInt8Ty()
|
||||
@@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) {
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
|
||||
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // use int16 as storage type
|
||||
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
||||
@@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
|
||||
*/
|
||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
using tt = ir::binary_op_t;
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using tt = ir::binary_op_t;
|
||||
switch(op) {
|
||||
case tt::Add: return ll::Add;
|
||||
case tt::FAdd: return ll::FAdd;
|
||||
@@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
|
||||
x->get_type()->get_scalar_ty()->is_fp32_ty()){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
|
||||
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
|
||||
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
|
||||
// manually select bf16 bin op
|
||||
if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) {
|
||||
assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty());
|
||||
if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b
|
||||
InlineAsm *bf16_add_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
"{ .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0x3f80U; \n\t" // 1.0
|
||||
" fma.rn.bf16 $0, $1, c, $2; } \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs});
|
||||
} else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a
|
||||
InlineAsm *bf16_sub_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
" { .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0xbf80U; \n\t" // -1.0
|
||||
" fma.rn.bf16 $0, $2, c, $1;} \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs});
|
||||
} else if (x->get_op() == tt::FMul) { // a * b = a*b + 0
|
||||
InlineAsm *bf16_mul_asm =
|
||||
InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false),
|
||||
" { .reg .b16 c; \n\t"
|
||||
" mov.b16 c, 0x8000U; \n\t" // -0.0 (sign bit only): adding -0.0 keeps the sign of a zero product
|
||||
" fma.rn.bf16 $0, $1, $2, c;} \n\t",
|
||||
"=h,h,h", false);
|
||||
vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs});
|
||||
} else
|
||||
throw std::runtime_error("invalid bin op for bf16");
|
||||
} else { // not bf16
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
|
||||
x->get_type()->get_scalar_ty()->is_fp32_ty()){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
|
||||
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
|
||||
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
|
||||
|
||||
}
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
has_l2_evict_policy = false;
|
||||
auto idxs = idxs_.at(val_op);
|
||||
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
|
||||
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||
ty = f16_ty;
|
||||
if(ty->isIntegerTy(1))
|
||||
ty = builder_->getInt8Ty();
|
||||
for(size_t i = 0; i < idxs.size(); i += vec){
|
||||
@@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = cvt(out->get_type()->get_scalar_ty());
|
||||
|
||||
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||
ty = f16_ty;
|
||||
|
||||
// Orders
|
||||
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
||||
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
||||
@@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){
|
||||
|
||||
// Lowers a floating-point constant: stores one LLVM value in vals_[x] for every
// index listed in idxs_.at(x).
//
// - bf16 constants are emitted through inline PTX (`mov.b16`) that loads the raw
//   16-bit pattern, returning a value of type bf16_ty.
// - Every other scalar type is materialized with ConstantFP::get.
void generator::visit_constant_fp(ir::constant_fp *x){
  auto *scalar_ty = x->get_type()->get_scalar_ty();

  // manually select bf16 constant
  if (scalar_ty->is_bf16_ty()) {
    // The bf16 pattern is the highest 16 bits of the fp32 encoding; the low
    // 16 bits are dropped (truncation, no rounding).
    float fp32_value = x->get_value();
    // Read the bits through APFloat instead of a reinterpret_cast between
    // unrelated pointer types, which violates strict aliasing.
    const uint32_t fp32_bits = static_cast<uint32_t>(
        llvm::APFloat(fp32_value).bitcastToAPInt().getZExtValue());
    const uint16_t bf16_raw = static_cast<uint16_t>(fp32_bits >> 16);

    std::stringstream const_str;
    const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned

    // The asm text does not depend on the index, so build it once.
    InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
                                           " mov.b16 $0, " + const_str.str() + ";",
                                           "=h", false);
    for(indices_t idx: idxs_.at(x))
      vals_[x][idx] = builder_->CreateCall(bf16_const, {});
    return;
  }

  // ConstantFP::get is only used for non-bf16 constants: the bf16 branch above
  // returns before reaching it.
  Type *ty = cvt(scalar_ty);
  for(indices_t idx: idxs_.at(x))
    vals_[x][idx] = ConstantFP::get(ty, x->get_value());
}
|
||||
|
||||
void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
|
@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
|
||||
return constant_int::get(ty, 0);
|
||||
case type::FP16TyID:
|
||||
return constant_fp::get(type::get_fp16_ty(ctx), 0);
|
||||
case type::BF16TyID:
|
||||
return constant_fp::get(type::get_bf16_ty(ctx), 0);
|
||||
case type::FP32TyID:
|
||||
return constant_fp::get(type::get_fp32_ty(ctx), 0);
|
||||
case type::FP64TyID:
|
||||
|
Reference in New Issue
Block a user