diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeLoadConvert.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeLoadConvert.cpp
new file mode 100644
index 000000000..f49fe7f0a
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeLoadConvert.cpp
@@ -0,0 +1,104 @@
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+using namespace mlir;
+
+// Rewrites load -> convert_layout(#blocked -> #dot_operand) into
+// insert_slice_async + async_wait + extract_slice, so the loaded tile
+// goes through shared memory instead of a register-layout conversion.
+class LoadConvertToInsertSlice : public mlir::RewritePattern {
+public:
+  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
+                             2, context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto shape = origRetType.getShape();
+    auto eltType = origRetType.getElementType();
+    auto dotOpEncoding = origRetType.getEncoding()
+                             .dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!dotOpEncoding)
+      return failure();
+    auto cvtArg = cvt.getOperand().getDefiningOp();
+    if (!cvtArg)
+      return failure();
+    auto loadOp = dyn_cast<triton::LoadOp>(*cvtArg);
+    if (!loadOp)
+      return failure();
+    auto blockedEncoding = loadOp.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    if (!blockedEncoding)
+      return failure();
+    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(
+        getContext(), dotOpEncoding, shape, blockedEncoding.getOrder(),
+        eltType);
+    // Allocate a shared-memory buffer with a leading depth-1 axis to
+    // receive the async copy.
+    auto srcTy = RankedTensorType::get({1, shape[0], shape[1]}, eltType,
+                                       sharedEncoding);
+    auto loadTensor =
+        rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
+
+    auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
+        op->getLoc(), loadTensor.getType(), loadOp.ptr(), loadTensor,
+        rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
+        loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
+        loadOp.isVolatile(), /*axis*/ 0);
+
+    rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
+    // Strip the leading depth-1 axis to recover the original 2-D tile.
+    auto tmpType =
+        RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
+    auto _0 = rewriter.getI64IntegerAttr(0);
+    auto _1 = rewriter.getI64IntegerAttr(1);
+    auto tmp = rewriter.create<tensor::ExtractSliceOp>(
+        op->getLoc(), tmpType, newOp,
+        SmallVector<OpFoldResult>{_0, _0, _0},
+        SmallVector<OpFoldResult>{_1, rewriter.getI64IntegerAttr(shape[0]),
+                                  rewriter.getI64IntegerAttr(shape[1])},
+        SmallVector<OpFoldResult>{_1, _1, _1});
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType,
+                                                              tmp);
+    return success();
+  }
+};
+
+class TritonGPUOptimizeLoadConvertPass
+    : public TritonGPUOptimizeLoadConvertBase<
+          TritonGPUOptimizeLoadConvertPass> {
+public:
+  TritonGPUOptimizeLoadConvertPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.add<LoadConvertToInsertSlice>(context);
+    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
+      signalPassFailure();
+    }
+  }
+};
+
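+// Factory entry point; the corresponding declaration comes from the
+// TritonGPU Transforms pass headers included above.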
+std::unique_ptr<Pass>
+mlir::createTritonGPUOptimizeLoadConvertPass() {
+  return std::make_unique<TritonGPUOptimizeLoadConvertPass>();
+}