testing things...

This commit is contained in:
Phil Tillet
2022-12-09 19:31:34 -08:00
parent fa6dbbff60
commit 58d2867fe6
6 changed files with 105 additions and 30 deletions

View File

@@ -713,9 +713,9 @@ public:
}
};
class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
@@ -725,18 +725,27 @@ public:
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!srcBlockedLayout || !dstDotOperandLayout)
if (!dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
(srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
@@ -862,7 +871,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeBlockedToDotOperand>(context);
// patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
@@ -873,6 +882,7 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};