From 0c5319eed9e1c85f4d1e6484affb125f02e15313 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 5 May 2022 20:56:55 +0800 Subject: [PATCH] More progress on SCF type conversion --- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 47 +++++++++++++++---- .../Transforms/TritonGPUConversion.cpp | 4 +- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 46ed700ba..7e4a49b16 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -4,8 +4,8 @@ #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +// #include "mlir/IR/BlockAndValueMapping.h" #include "../PassDetail.h" -#include using namespace mlir; using namespace mlir::triton; @@ -222,16 +222,38 @@ void populateTritonPatterns( // struct SCFForPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - + // Ref: ConvertForOpTypes LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newFor = rewriter.replaceOpWithNewOp( - op, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), - adaptor.getInitArgs() - ); - // TODO: we need to copy (?) the body of ForOp - llvm_unreachable("Not implemented"); - // newFor.getRegion().takeBody(adaptor.getRegion()); + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), + newOp.getLoopBody().end()); + + // Now, update all the types. + + // Convert the type of the entry block of the ForOp's body. + if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a BlockAndValueMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + return success(); + return success(); } }; @@ -241,8 +263,11 @@ struct SCFYieldPattern : public OpConversionPattern { LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); + // rewriter.create(op.getLoc(), adaptor.getOperands()); + // op.erase(); rewriter.replaceOpWithNewOp( - op, adaptor.getResults() + op, adaptor.getOperands() ); return success(); } @@ -274,6 +299,8 @@ public: // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); if(failed(applyPartialConversion(mod, target, diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 4067156dc..1d1e35a61 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -50,7 +50,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // materailizations addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { - // llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"; + llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n" + << "in: \n"; + inputs[0].dyn_cast().getOwner()->getParentOp()->getParentOp()->print(llvm::errs()); llvm_unreachable("Not implemented"); return llvm::None; });