[Triton-MLIR][Backend] Port FMADot conversion for DotOp (#844)

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
This commit is contained in:
Yan Chunwei
2022-11-09 12:57:50 +08:00
committed by GitHub
parent de5b84c476
commit 0c87360657
4 changed files with 240 additions and 5 deletions

View File

@@ -576,6 +576,14 @@ public:
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
// for FMA, should retain the blocked layout.
if (A.getElementType().isF32() && B.getElementType().isF32() &&
!dotOp.allowTF32())
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
@@ -629,4 +637,4 @@ public:
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
return std::make_unique<TritonGPUCombineOpsPass>();
}
}