[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)

This commit is contained in:
daadaada
2022-01-12 02:20:31 +08:00
committed by GitHub
parent efdabe6073
commit 94a2e10fe5
17 changed files with 717 additions and 263 deletions

View File

@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
// matmul_inst classes
//===----------------------------------------------------------------------===//
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
allow_tf32_ = allow_tf32;
}
instruction *dot_inst::create(value *A, value *B, value *C,
bool AT, bool BT,
bool AT, bool BT, bool allow_tf32,
const std::string &name, instruction *next) {
TransT OPA = AT ? Trans : NoTrans;
TransT OPB = BT ? Trans : NoTrans;
return new dot_inst(A, B, C, OPA, OPB, name, next);
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
}
instruction *dot_inst::create_nn(value *A, value *B, value *C,
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
}
instruction *dot_inst::create_nt(value *A, value *B, value *C,
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, Trans, name, next);
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
}
instruction *dot_inst::create_tn(value *A, value *B, value *C,
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, NoTrans, name, next);
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
}
instruction *dot_inst::create_tt(value *A, value *B, value *C,
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, Trans, name, next);
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
}
//===----------------------------------------------------------------------===//