#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

using namespace mlir;

/// Module pass that decomposes layout conversions targeting a dot-operand
/// encoding into two consecutive conversions:
///
///   src -> (shared encoding derived from the dot operand) -> dot operand
///
/// Conversions whose source already carries a shared encoding, and
/// conversions whose destination is not a dot-operand encoding, are left
/// untouched.
///
/// NOTE(review): the explicit template arguments on `cast`/`isa`/`dyn_cast`/
/// `create` below were missing from the text under review and have been
/// reconstructed from how the values are used — confirm against upstream.
class TritonGPUDecomposeConversionsToDotOperandPass
    : public TritonGPUDecomposeConversionsToDotOperandBase<
          TritonGPUDecomposeConversionsToDotOperandPass> {
public:
  TritonGPUDecomposeConversionsToDotOperandPass() = default;

  /// Walks every `triton::gpu::ConvertLayoutOp` in the module and rewrites
  /// the ones that convert a non-shared layout into a dot-operand layout.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();

    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();

      // Nothing to decompose when the source is already in a shared encoding.
      auto srcEncoding = srcType.getEncoding();
      if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
        return;

      // Only conversions that produce a dot-operand encoding are rewritten.
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (!dstDotOp)
        return;

      // Intermediate tensor type: same shape and element type as the
      // destination, but with a shared encoding built from the dot-operand
      // encoding and the source's shape, order and element type.
      auto tmpType = RankedTensorType::get(
          dstType.getShape(), dstType.getElementType(),
          triton::gpu::SharedEncodingAttr::get(
              context, dstDotOp, srcType.getShape(),
              triton::gpu::getOrder(srcEncoding), srcType.getElementType()));

      // The builder's insertion point is set from `cvtOp`, so both
      // replacement ops are created next to the op they replace.
      OpBuilder builder(cvtOp);
      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), tmpType, cvtOp.getOperand());
      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), dstType, tmp);

      // Redirect all users to the two-step conversion, then drop the
      // original single-step op.
      cvtOp.replaceAllUsesWith(newConvert.getResult());
      cvtOp.erase();
    });
  }
};

/// Creates a new instance of the decompose-conversions-to-dot-operand pass.
std::unique_ptr<Pass>
mlir::createTritonGPUDecomposeConversionsToDotOperandPass() {
  return std::make_unique<TritonGPUDecomposeConversionsToDotOperandPass>();
}