fix 6/7 dot tests

This commit is contained in:
Michael Melesse
2022-11-01 14:18:06 +00:00
parent 4f3e2d6ed7
commit 4fb9d4904e
3 changed files with 15 additions and 7 deletions

View File

@@ -2973,19 +2973,20 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
bool is_outer = NK == 1;
#ifdef USE_ROCM
bool is_mma = false;
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
#else
bool is_mma = layouts_->get(dot)->to_mma();
#endif
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
// if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
// A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
// throw std::runtime_error("dot has invalid operand type");
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
throw std::runtime_error("dot has invalid operand type");
#endif
}
void generator::visit_trans_inst(ir::trans_inst* trans) {