[Triton-MLIR][Backend] Port FMADot conversion for DotOp (#844)
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
This commit is contained in:
@@ -576,6 +576,14 @@ public:
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
|
||||
// for FMA, should retain the blocked layout.
|
||||
if (A.getElementType().isF32() && B.getElementType().isF32() &&
|
||||
!dotOp.allowTF32())
|
||||
return failure();
|
||||
|
||||
// get MMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
@@ -629,4 +637,4 @@ public:
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user