[BACKEND] bf16 flash-attention (#636)

This commit is contained in:
Da Yan
2022-08-27 11:40:55 +08:00
committed by GitHub
parent fe0c29b9ec
commit 210a296699

View File

@@ -2125,15 +2125,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty; phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
// FIXME: We should use bf16 here. mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(bf16_ty, 3);
smem_ptr_ty = ptr_ty(f16_ty, 3); ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false); phi_ty = bf16x2_ty;
phi_ty = fp16x2_ty;
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3); smem_ptr_ty = ptr_ty(fp32_ty, 3);
@@ -2233,10 +2228,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
// std::cout << idxs_[A].size() << std::endl; // std::cout << idxs_[A].size() << std::endl;
// std::cout << (m+1)*ldm + k*2 + 3 << std::endl; // std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
// int ldm = num_rep_k*4; // int ldm = num_rep_k*4;
Value* ha0 = UndefValue::get(fp16x2_ty); Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2
Value* ha1 = UndefValue::get(fp16x2_ty); Value* ha1 = UndefValue::get(phi_ty);
Value* ha2 = UndefValue::get(fp16x2_ty); Value* ha2 = UndefValue::get(phi_ty);
Value* ha3 = UndefValue::get(fp16x2_ty); Value* ha3 = UndefValue::get(phi_ty);
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));