[OPTIMIZER] Not using MMA on FP32 when allowTF32 is false
```diff
@@ -570,6 +570,12 @@ public:
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
+    auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
+    auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
+    if (A.getElementType().isF32() && B.getElementType().isF32() &&
+        !dotOp.allowTF32())
+      return failure();
+
     // get MMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
```
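Context for the change: NVIDIA tensor cores execute FP32 matmuls only via TF32, which keeps the FP32 exponent range but truncates the significand to 10 explicit bits. Before this commit, the rewrite pattern would convert an FP32 dot to an MMA layout even when the user had disallowed TF32, silently losing precision. The new guard makes the pattern fail for FP32 × FP32 dots with `allowTF32 == false`, so the dot keeps its blocked layout and runs as full-precision FP32 FMAs. Below is a minimal standalone sketch of that decision; the enum and helper names are hypothetical, not the actual MLIR/Triton pattern API:

```cpp
// Standalone sketch (hypothetical types/names) of the early exit this
// commit adds to the blocked-to-MMA rewrite pattern.
#include <cstdio>

enum class ElemType { F16, BF16, F32 };

// FP32 tensor-core math requires TF32, so an FP32 x FP32 dot with
// allowTF32 == false must not be rewritten to an MMA layout.
bool canRewriteToMMA(ElemType a, ElemType b, bool allowTF32) {
  if (a == ElemType::F32 && b == ElemType::F32 && !allowTF32)
    return false; // corresponds to `return failure();` in the pattern
  return true;
}

int main() {
  std::printf("f32*f32, allowTF32=false -> MMA? %s\n",
              canRewriteToMMA(ElemType::F32, ElemType::F32, false) ? "yes" : "no");
  std::printf("f32*f32, allowTF32=true  -> MMA? %s\n",
              canRewriteToMMA(ElemType::F32, ElemType::F32, true) ? "yes" : "no");
  std::printf("f16*f16, allowTF32=false -> MMA? %s\n",
              canRewriteToMMA(ElemType::F16, ElemType::F16, false) ? "yes" : "no");
  return 0;
}
```

At the user level, this corresponds (at the time of this commit) to the `allow_tf32` flag on `tl.dot`: with it set to false, an FP32 dot stays off the tensor cores rather than being downcast to TF32.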