save changes
This commit is contained in:
@@ -2974,18 +2974,18 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
unsigned NK = A_shapes[red_axis];
|
||||
bool is_outer = NK == 1;
|
||||
#ifdef USE_ROCM
|
||||
bool is_mma = layouts_->get(dot)->to_mma();
|
||||
#else
|
||||
bool is_mma = false;
|
||||
#else
|
||||
bool is_mma = layouts_->get(dot)->to_mma();
|
||||
#endif
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
|
||||
return visit_mma884(dot, A, B, D, NK);
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
||||
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
||||
A->get_type()->get_scalar_ty()->is_fp32_ty())
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
throw std::runtime_error("dot has invalid operand type");
|
||||
// if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
||||
// A->get_type()->get_scalar_ty()->is_fp32_ty())
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
// throw std::runtime_error("dot has invalid operand type");
|
||||
}
|
||||
|
||||
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
||||
|
@@ -7,7 +7,7 @@ sudo apt install gdb -y
|
||||
|
||||
gdb -ex "set pagination off" \
|
||||
-ex "file python" \
|
||||
-ex 'run -m pytest --capture=tee-sys --verbose "python/test/unit/language/test_core.py::test_dot"' \
|
||||
-ex 'run -m pytest --capture=tee-sys --verbose "python/test/unit/language/test_core.py::test_dot[none-False-float16]"' \
|
||||
-ex "backtrace" \
|
||||
-ex "set confirm off" \
|
||||
-ex "q" \
|
||||
|
Reference in New Issue
Block a user