[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)
This commit is contained in:
@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
allow_tf32_ = allow_tf32;
|
||||
}
|
||||
|
||||
instruction *dot_inst::create(value *A, value *B, value *C,
|
||||
bool AT, bool BT,
|
||||
bool AT, bool BT, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
TransT OPA = AT ? Trans : NoTrans;
|
||||
TransT OPB = BT ? Trans : NoTrans;
|
||||
return new dot_inst(A, B, C, OPA, OPB, name, next);
|
||||
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nn(value *A, value *B, value *C,
|
||||
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
|
||||
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_nt(value *A, value *B, value *C,
|
||||
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, NoTrans, Trans, name, next);
|
||||
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tn(value *A, value *B, value *C,
|
||||
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, NoTrans, name, next);
|
||||
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
instruction *dot_inst::create_tt(value *A, value *B, value *C,
|
||||
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
|
||||
const std::string &name, instruction *next) {
|
||||
return new dot_inst(A, B, C, Trans, Trans, name, next);
|
||||
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Reference in New Issue
Block a user