[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)

This commit is contained in:
daadaada
2022-01-12 02:20:31 +08:00
committed by GitHub
parent efdabe6073
commit 94a2e10fe5
17 changed files with 717 additions and 263 deletions

View File

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
@@ -39,10 +40,10 @@ void swizzle::run(ir::module &) {
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else{
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}