[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)
This commit is contained in:
@@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
builder.set_insert_point(add);
|
||||
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
|
||||
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
|
||||
add->replace_all_uses_with(new_dot);
|
||||
return true;
|
||||
}
|
||||
|
@@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) {
|
||||
std::vector<ir::dot_inst*> to_prefetch;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
||||
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
||||
// Now only do prefetching when dot is fp16
|
||||
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
|
||||
// Now only do prefetching when dot is using tensor cores
|
||||
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
|
||||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
|
||||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
|
||||
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
)
|
||||
)
|
||||
return;
|
||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||
|
Reference in New Issue
Block a user