diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d6d059859..d4829e680 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2125,15 +2125,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = fp16x2_ty; } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { - // FIXME: We should use bf16 here. - mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - smem_ptr_ty = ptr_ty(f16_ty, 3); - ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp16x2_ty; - // mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - // smem_ptr_ty = ptr_ty(bf16_ty, 3); - // ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - // phi_ty = bf16x2_ty; + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(bf16_ty, 3); + ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = bf16x2_ty; } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(fp32_ty, 3); @@ -2233,10 +2228,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // std::cout << idxs_[A].size() << std::endl; // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; // int ldm = num_rep_k*4; - Value* ha0 = UndefValue::get(fp16x2_ty); - Value* ha1 = UndefValue::get(fp16x2_ty); - Value* ha2 = UndefValue::get(fp16x2_ty); - Value* ha3 = UndefValue::get(fp16x2_ty); + Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2 + Value* ha1 = UndefValue::get(phi_ty); + Value* ha2 = UndefValue::get(phi_ty); + Value* ha3 = UndefValue::get(phi_ty); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));