[OPTIMIZER] Not using MMA on FP32 when allowTF32 is false

This commit is contained in:
Phil Tillet
2022-11-04 23:16:28 -07:00
parent b39cc56f93
commit d767919bc1

View File

@@ -570,6 +570,12 @@ public:
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
if (A.getElementType().isF32() && B.getElementType().isF32() &&
!dotOp.allowTF32())
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();