[IR] Preliminary support for BF16 (#129)
This PR adds a BF16 data type, along with FP32 <-> BF16 conversion instructions in the LLVM codegen. Other operations on bfloat16 are not yet supported.
committed by Philippe Tillet · parent 9b4e2cae2d · commit 8cea583109
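For context, BF16 keeps FP32's sign bit and 8-bit exponent and truncates the mantissa to 7 bits, so a bfloat16 value is exactly the upper half of the corresponding FP32 bit pattern. A minimal standalone C++ illustration of the two directions (helper names are hypothetical; the codegen below does the same thing with vector bitcasts):

#include <cstdint>
#include <cstring>

// FP32 -> BF16 by truncation: keep the high 16 bits; low mantissa bits are dropped.
static uint16_t fp32_to_bf16_trunc(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return (uint16_t)(u >> 16);
}

// BF16 -> FP32 is always exact: the bf16 bits become the high half, the low half is zero.
static float bf16_to_fp32_bits(uint16_t h) {
  uint32_t u = (uint32_t)h << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}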
@@ -161,11 +161,10 @@ Type *generator::cvt(ir::type *ty) {
   switch(ty->get_type_id()){
     case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
     case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
-    case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
-    case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
-    case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
-    case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_);
-    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_);
+    case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
+    case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
+    case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
+    case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
     case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
     case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
     case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
@@ -428,57 +427,74 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0

 std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
   Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
   InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
                                   "{"
                                   ".reg .b32 a<2>, b<2>; \n\t"
                                   "prmt.b32 a0, 0, $2, 0x5140; \n\t"
                                   "prmt.b32 a1, 0, $2, 0x7362; \n\t"
                                   "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
                                   "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
                                   "shr.b32 b0, b0, 1; \n\t" // shift into fp16 position
                                   "shr.b32 b1, b1, 1; \n\t"
                                   "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
                                   "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
                                   "}", "=r,=r,r", false);
   Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
   packed_in = insert_elt(packed_in, in0, (int)0);
   packed_in = insert_elt(packed_in, in1, (int)1);
   packed_in = insert_elt(packed_in, in2, (int)2);
   packed_in = insert_elt(packed_in, in3, (int)3);
   Value *in = bit_cast(packed_in, i32_ty);
   Value *ret = call(ptx, {in});
   Value *packed_ret0 = extract_val(ret, {0});
   Value *packed_ret1 = extract_val(ret, {1});
   Value *ret0 = extract_elt(packed_ret0, (int)0);
   Value *ret1 = extract_elt(packed_ret0, (int)1);
   Value *ret2 = extract_elt(packed_ret1, (int)0);
   Value *ret3 = extract_elt(packed_ret1, (int)1);
   return std::make_tuple(ret0, ret1, ret2, ret3);
 }

+// BF16 -> FP32 is exact: place the bf16 bits in the high half of an fp32
+Value* generator::bf16_to_fp32(Value *in0){
+  Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
+  ret = insert_elt(ret, in0, (uint64_t)1);
+  ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
+  return bit_cast(ret, builder_->getFloatTy());
+}
+
+// FP32 -> BF16: use cvt.rn (round-to-nearest-even) on Ampere+, else truncate
+Value* generator::fp32_to_bf16(Value *in0){
+  if(tgt_->as_nvidia()->sm() >= 80){
+    InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
+                                    "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
+    return call(ptx, {in0});
+  }
+  return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
+}
+
 /**
  * \brief Code Generation for `cast`
  */
 void generator::visit_cast_inst(ir::cast_inst* x) {
-  // <> FP8
   ir::value *op = x->get_operand(0);
   ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
   ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
+  auto x_idxs = idxs_.at(x);
+  auto op_idxs = idxs_.at(op);
+
+  // <> FP8
   if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
     // ensure that conversions can be vectorized
     int ld = layouts_->get(x)->get_order(0);
     int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
     if(contiguous % 4 != 0)
       throw std::runtime_error("unsupported fp32 -> fp8 conversion");
-    auto x_idxs = idxs_.at(x);
-    auto op_idxs = idxs_.at(op);

     // run the conversion
     auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
-      if(op_sca_ty->is_float_ty() && ret_sca_ty->is_fp8_ty())
+      if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty())
         return fp32x4_to_fp8x4(a, b, c, d);
-      if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
+      if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty())
         return fp8x4_to_fp16x4(a, b, c, d);
       throw std::runtime_error("unsupported conversion");
     };
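A note on the fp8x4_to_fp16x4 asm above: prmt scatters each input byte into the high byte of a 16-bit lane, the lop3/shr pair shifts exponent and mantissa into fp16 position with the sign stripped, and the final lop3 merges the sign back in. A scalar C++ emulation of one lane, shown only to illustrate the bit manipulation (hypothetical helper, not part of the PR):

#include <cstdint>

// One fp8 byte -> fp16 bit pattern, mirroring the prmt/lop3/shr sequence.
static uint16_t fp8_to_fp16_bits(uint8_t x) {
  uint16_t a = (uint16_t)x << 8;   // prmt: byte into the high half of the lane
  uint16_t b = (a & 0x7fff) >> 1;  // strip sign, shift exp+mantissa into place
  return b | (a & 0x8000);         // restore the sign bit
}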
@@ -494,6 +510,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
     return;
   }

+  // <> BF16
+  if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
+    // FP32 -> BF16
+    if(op_sca_ty->is_fp32_ty())
+      for(size_t i = 0; i < x_idxs.size(); i++)
+        vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
+    // BF16 -> FP32
+    if(ret_sca_ty->is_fp32_ty())
+      for(size_t i = 0; i < x_idxs.size(); i++)
+        vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
+    return;
+  }
+
   Type *ty = cvt(x->get_type()->get_scalar_ty());
   auto cvt = [](ir::cast_op_t op){
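Note the asymmetry in fp32_to_bf16: on sm_80+ the hardware cvt.rn.bf16.f32 rounds to nearest-even, while the pre-Ampere fallback truncates. If bit-exact parity with cvt.rn were needed, the fallback could use the usual bias-and-shift emulation, sketched here (not part of this PR; NaN handling omitted):

#include <cstdint>
#include <cstring>

// Round-to-nearest-even FP32 -> BF16: the bias carries into bit 16 exactly
// when the discarded low half should round the kept half up.
static uint16_t fp32_to_bf16_rne(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t bias = 0x7fff + ((u >> 16) & 1);  // ties go to the even result
  return (uint16_t)((u + bias) >> 16);
}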
@@ -675,7 +704,6 @@ void generator::visit_load_inst(ir::load_inst* x){
         curr = extract_val(_ret, {ii});
       else
         curr = _ret;
-      // std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl;
       rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
     }
     int tmp = (width / (dtsize * 8));
@@ -694,6 +722,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
 /**
  * \brief Code Generation for a (synchronous) `store`
  */
+
 void generator::visit_store_inst(ir::store_inst * x){
   ir::masked_store_inst *mx = dynamic_cast<ir::masked_store_inst*>(x);
   // operands
@@ -740,6 +769,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
   visit_store_inst(x);
 }

+
 /**
  * \brief Code Generation for `reshape`
  */
@@ -901,7 +931,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
     int ld = ords_.at(ptr)[0];
     unsigned alignment = alignment_->get(ptr, ld);
     vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
-    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1);
+    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
   }

   for(int i = 0; i < idxs_.at(val).size(); i += vec){
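The clamp above only ever vectorizes fp16 atomics, which lines up with PTX exposing packed f16x2 red/atom operations while other element types go scalar. A sketch of the width selection (names hypothetical):

#include <algorithm>

// Atomic vector width: contiguity and alignment permitting, fp16 elements
// may pair into one f16x2 operation; everything else stays scalar.
static int atomic_vec_width(int contiguous, int alignment, bool elem_is_fp16) {
  int vec = std::min(contiguous, alignment);
  return std::min(vec, elem_is_fp16 ? 2 : 1);
}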
@@ -1105,10 +1135,10 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
   ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);

   // Cache lds value. If values are prefetched, create phi node
   // @param inc: incoming block (0 = header, 1 = loop)
   auto register_lds =
     [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
       if (K == 0 && is_prefetch) {
         ir::basic_block* inc_block = phiA->get_incoming_block(inc);
@@ -1208,7 +1238,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
       load_a(m, 0, 0, true);
     for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
       load_b(n, 0, 0, true);

     // update accumulators
     builder_->SetInsertPoint(curr_bb);
     for (unsigned K = 0; K < NK; K += 4) {
@@ -1225,7 +1255,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
           call_mma(m, n, K);
       }
     }
   } else { // not prefetched
     for(unsigned K = 0; K < NK; K += 4)
       for(unsigned m = 0; m < num_m/2; m++)
         for(unsigned n = 0; n < num_n/2; n++) {
@@ -1356,7 +1386,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                                 "{$0, $1, $2, $3}, "
                                 "{$4, $5, $6, $7}, "
                                 "{$8, $9}, "
                                 "{$10, $11, $12, $13};",
                                 "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);

   unsigned num_rep_0 = shapes[0] / layout->spt(0);
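For orientation, the constraint string "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3" encodes four f32 outputs tied to four accumulator inputs, four .b32 registers of A fragments, and two of B. That matches the per-thread fragment sizes of a warp-wide m16n8k16 mma with 16-bit inputs, derived here from the tile shape (an inference, not stated in the diff):

// Per-thread fragment sizes for one m16n8k16 mma across a 32-thread warp;
// each pair of 16-bit values packs into one .b32 register.
constexpr int kWarp = 32;
constexpr int kAElems = 16 * 16 / kWarp; // 8 halves -> 4 .b32 registers
constexpr int kBElems = 16 * 8 / kWarp;  // 4 halves -> 2 .b32 registers
constexpr int kAccums = 16 * 8 / kWarp;  // 4 f32 accumulators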
@@ -1416,8 +1446,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
       int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
       int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
       InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
                                                      "{$0, $1, $2, $3}, [$4 + " +
                                                      std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
                                                      "=r,=r,=r,=r,r", true);
       Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
       if(K == 0 && inc == 1 && is_prefetch)
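The immediate folded into the ldmatrix string is a byte offset; the leading factor of 2 converts element strides to bytes, since fp16 and bf16 tiles both use 2-byte elements. The same computation, restated as a hypothetical helper:

// Byte offset of the A fragment in shared memory, as baked into the
// ldmatrix string above (element offset scaled by sizeof a 16-bit element).
static int ld_a_byte_offset(int step_am, int step_ak, int wpt0,
                            int stride_a_m, int stride_a_k) {
  return 2 * step_am * 16 * wpt0 * stride_a_m + 2 * step_ak * stride_a_k;
}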
@@ -1444,8 +1474,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
       int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
       int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
       InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
                                                     "{$0, $1, $2, $3}, [$4 + " +
                                                     std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
                                                     "=r,=r,=r,=r,r", true);
       Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
       if(K == 0 && inc == 1 && is_prefetch)
@@ -2058,7 +2088,7 @@ void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
   } else {
     // If dot has been visited, insert prefetched lds
     assert(inc == 1);
     assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
            "dot hasn't been visited");
     // sink lds & extract element
     // move lds & all uses to current location
@@ -2081,7 +2111,7 @@ void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
       assert(m_instr->getParent() == &*builder_->GetInsertBlock());
       builder_->SetInsertPoint(m_instr->getParent());
     }
   }
 }

 void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
@@ -2384,7 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
     builder_->SetInsertPoint(parent);
   else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
     builder_->SetInsertPoint(&*parent->getFirstNonPHI());
   } else
     builder_->SetInsertPoint(parent);

   // create smem_idx
@@ -2507,7 +2537,7 @@ void generator::finalize_shared_layout(analysis::shared_layout *shared) {
   Value *idx = smem_idx[shared];
   builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
   Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
   PHINode *_ret = phi(i32_ty, 2);
   Instruction *then_term = nullptr;
   Instruction *else_term = nullptr;
   Instruction *dummy = builder_->CreateRet(nullptr);
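The icmp/phi pair above implements the rotating stage index for the N-buffered shared-memory layout; conceptually it computes the following (assumed semantics, not code from this PR):

// Advance the shared-memory stage index, wrapping after the last stage.
static int next_stage(int idx, int num_stages) {
  return idx == num_stages - 1 ? 0 : idx + 1;
}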
@@ -2544,7 +2574,7 @@ void generator::finalize_shared_layout(analysis::shared_layout *shared) {
   if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
     curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
     curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
   } else
     throw std::runtime_error("Should be PHINode");

   BasicBlock *current = builder_->GetInsertBlock();