Fix v100 fp32
This commit is contained in:
@@ -110,6 +110,19 @@ bool maybeAliasOp(Operation *op) {
|
||||
isa<tensor::InsertSliceOp>(op);
|
||||
}
|
||||
|
||||
bool supportMMA(triton::DotOp op, int version) {
|
||||
// Refer to mma section for the data type supported by Volta and Hopper
|
||||
// Tensor Core in
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
|
||||
return (aElemTy.isF16() && bElemTy.isF16()) ||
|
||||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
|
||||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
|
||||
version >= 2) ||
|
||||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2);
|
||||
}
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state) {
|
||||
std::string opName;
|
||||
llvm::raw_string_ostream ss(opName);
|
||||
|
Reference in New Issue
Block a user