From d767919bc139d3a7ed2128e1b05b71cfdfe54063 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 4 Nov 2022 23:16:28 -0700 Subject: [PATCH] [OPTIMIZER] Not using MMA on FP32 when allowTF32 is false --- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 456ce1200..06139cecd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -570,6 +570,12 @@ public: auto oldRetType = dotOp.getResult().getType().cast(); if (oldRetType.getEncoding().isa()) return failure(); + auto A = dotOp.getOperand(0).getType().cast(); + auto B = dotOp.getOperand(1).getType().cast(); + if (A.getElementType().isF32() && B.getElementType().isF32() && + !dotOp.allowTF32()) + return failure(); + // get MMA encoding for the given number of warps auto retShape = oldRetType.getShape(); auto mod = op->getParentOfType();