[FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943)
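
The diff excerpt below covers only the backend layout-conversion pattern; the new frontend trans instruction itself is not shown here. As a rough illustration only, a kernel using it might look like the following sketch, assuming the frontend exposes the instruction as tl.trans (the kernel, shapes, and pointer arithmetic are illustrative, not taken from this commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def transpose_kernel(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr):
        # Load an (M, N) tile, transpose it, and store the (N, M) result.
        offs_m = tl.arange(0, M)
        offs_n = tl.arange(0, N)
        x = tl.load(x_ptr + offs_m[:, None] * N + offs_n[None, :])
        y = tl.trans(x)  # assumed frontend spelling of the new trans instruction
        tl.store(y_ptr + offs_n[:, None] * M + offs_m[None, :], y)

    x = torch.randn((32, 64), device="cuda")
    y = torch.empty((64, 32), device="cuda")
    transpose_kernel[(1,)](x, y, M=32, N=64)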
@@ -178,6 +178,10 @@ public:
        !isSharedEncoding(convert.getResult())) {
      return mlir::failure();
    }
    if (isSharedEncoding(convert.getOperand()) &&
        isSharedEncoding(convert.getResult())) {
      return mlir::failure();
    }
    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
    auto srcShared =
        srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
@@ -661,6 +665,54 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,

} // namespace

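// The pattern below matches a ConvertLayoutOp from a #blocked layout to a
// #shared layout whose order differs from the source's, and whose single user
// is a triton::TransOp. It rewrites this into a conversion to a temporary
// #shared layout that keeps the blocked source's order, followed by a
// transpose that produces the originally requested #shared layout.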
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
  OptimizeBlockedToShared(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    auto srcBlockedLayout =
        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto dstSharedLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    if (!srcBlockedLayout || !dstSharedLayout)
      return failure();
    if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
      return failure();
    // For now only works if single use is transpose
    // TODO: rematerialize #shared uses
    auto users = op->getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<triton::TransOp>(*users.begin()))
      return failure();

    auto tmpShared = triton::gpu::SharedEncodingAttr::get(
        op->getContext(), dstSharedLayout.getVec(),
        dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
        srcBlockedLayout.getOrder());
    auto tmpType = RankedTensorType::get(srcType.getShape(),
                                         srcType.getElementType(), tmpShared);
    auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), tmpType, cvt.getOperand());

    auto newDstType = RankedTensorType::get(
        users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
        srcType.getElementType(), dstSharedLayout);

    auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
                                                     tmpCvt.getResult());

    rewriter.replaceOp(*users.begin(), newTrans.getResult());
    return success();
  }
};

class BlockedToMMA : public mlir::RewritePattern {
  int computeCapability;

@@ -755,6 +807,7 @@ public:

  mlir::RewritePatternSet patterns(context);

  patterns.add<OptimizeBlockedToShared>(context);
  patterns.add<SimplifyConversion>(context);
  patterns.add<DecomposeDotOperand>(context);
  patterns.add<RematerializeBackward>(context);