Phil Tillet
2022-12-28 13:42:43 -08:00
parent 7aba2a60d6
commit 54ae3e8d6e
7 changed files with 47 additions and 82 deletions


@@ -19,7 +19,6 @@
 #include <memory>
 
-#define int_attr(num) rewriter.getI64IntegerAttr(num)
 
 using namespace mlir;
 
 namespace {
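(The hunk above also drops the int_attr helper macro; its only user in this diff is the LoadConvertToInsertSlice pattern removed below. As a reference, a minimal sketch of what the macro expanded to, written as a plain function; the name intAttr is illustrative, not from the commit:)

#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"

// Sketch: int_attr(num) expanded to rewriter.getI64IntegerAttr(num), i.e. an
// i64 IntegerAttr, which tensor::ExtractSliceOp accepts as a static
// OpFoldResult for its offsets/sizes/strides.
static mlir::OpFoldResult intAttr(mlir::OpBuilder &b, int64_t v) {
  return b.getI64IntegerAttr(v);
}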
@@ -1155,60 +1154,6 @@ public:
   }
 };
-class LoadConvertToInsertSlice : public mlir::RewritePattern {
-public:
-  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             2, context) {}
-
-  mlir::LogicalResult
-  matchAndRewrite(mlir::Operation *op,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
-    auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
-    auto shape = origRetType.getShape();
-    auto eltType = origRetType.getElementType();
-    auto dotOpEncoding = origRetType.getEncoding()
-                             .dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-    if (!dotOpEncoding) {
-      return failure();
-    }
-    auto loadOp = dyn_cast<triton::LoadOp>(*cvt.getOperand().getDefiningOp());
-    if (!loadOp) {
-      return failure();
-    }
-    auto blockedEncoding = loadOp.getType()
-                               .cast<RankedTensorType>()
-                               .getEncoding()
-                               .dyn_cast<triton::gpu::BlockedEncodingAttr>();
-    if (!blockedEncoding)
-      return failure();
-    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(
-        getContext(), dotOpEncoding, shape, blockedEncoding.getOrder(),
-        eltType);
-    auto srcTy = RankedTensorType::get({1, shape[0], shape[1]}, eltType,
-                                       sharedEncoding);
-    auto loadTensor =
-        rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
-    auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
-        op->getLoc(), loadTensor.getType(), loadOp.ptr(), loadTensor,
-        rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
-        loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
-        loadOp.isVolatile(), /*axis*/ 0);
-    rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
-    auto tmpType =
-        RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
-    auto tmp = rewriter.create<tensor::ExtractSliceOp>(
-        op->getLoc(), tmpType, newOp,
-        SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
-        SmallVector<OpFoldResult>{int_attr(1), int_attr(shape[0]),
-                                  int_attr(shape[1])},
-        SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
-    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType,
-                                                              tmp);
-    return success();
-  }
-};
-
 class FixupLoop : public mlir::RewritePattern {
 public:
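(The deleted pattern fired on a triton_gpu.convert_layout whose result had a dot-operand encoding and whose operand came straight from a tt.load with a blocked encoding; it replaced the in-register load with an asynchronous copy staged through shared memory. Roughly, in IR terms; this is a hand-written sketch with abbreviated types and encodings, not compiler output:)

// before: load into registers, then convert to the dot-operand layout
%v   = tt.load %ptr, %mask, %other : tensor<128x64xf16, #blocked>
%cvt = triton_gpu.convert_layout %v : tensor<128x64xf16, #dot_op>

// after: async-copy into a 1x128x64 shared-memory buffer at index 0, wait,
// slice the 2-D tensor back out, and convert to the original dot-operand type
%buf = triton_gpu.alloc_tensor : tensor<1x128x64xf16, #shared>
%ins = triton_gpu.insert_slice_async %ptr, %buf, %c0, %mask, %other
         : tensor<128x64x!tt.ptr<f16>, #blocked> -> tensor<1x128x64xf16, #shared>
triton_gpu.async_wait {num = 0 : i32}
%slc = tensor.extract_slice %ins[0, 0, 0] [1, 128, 64] [1, 1, 1]
         : tensor<1x128x64xf16, #shared> to tensor<128x64xf16, #shared>
%cvt = triton_gpu.convert_layout %slc : tensor<128x64xf16, #dot_op>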
@@ -1280,7 +1225,6 @@ public:
     patterns.add<MoveConvertOutOfLoop>(context);
     patterns.add<MoveConvertOutOfIf>(context);
     patterns.add<BlockedToMMA>(context, computeCapability);
-    patterns.add<LoadConvertToInsertSlice>(context);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
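(The surviving registration code follows the standard MLIR greedy-rewrite recipe: collect patterns into a RewritePatternSet, then drive them to a fixpoint. A minimal self-contained sketch of the same mechanism; the pass name and the commented-out pattern are placeholders, not Triton's:)

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch of the surrounding pass structure: patterns are registered with a
// RewritePatternSet and applied greedily to the module until no more match.
struct CombinePassSketch
    : public mlir::PassWrapper<CombinePassSketch,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();
    mlir::ModuleOp m = getOperation();
    mlir::RewritePatternSet patterns(context);
    // patterns.add<SomeRewritePattern>(context);  // e.g. the patterns above
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};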