From 53cf93ce6a2146a07a31d36d11d2c88c0dbbca06 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 18 Jun 2022 14:57:41 +0800 Subject: [PATCH] Revert "Remove TypeConverter from TritonToTritonGPU conversion" This reverts commit 64d0b87ef06fc689629220c8fb0522cddcb271d0. --- .../Transforms/TritonGPUConversion.h | 3 +- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 29 ++++++++++++------- .../Transforms/TritonGPUConversion.cpp | 21 +++++--------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index fddcf2905..2f34d71f7 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -20,8 +20,9 @@ private: }; class TritonGPUConversionTarget : public ConversionTarget { + TritonGPUTypeConverter &typeConverter; public: - explicit TritonGPUConversionTarget(MLIRContext &ctx); + explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); }; } // namespace mlir diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 81ef9e790..591af75b5 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -73,7 +73,8 @@ public: }; void populateArithmeticPatternsAndLegality( - RewritePatternSet &patterns, TritonGPUConversionTarget &target){ + TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns, + TritonGPUConversionTarget &target){ // -------------- // Add legality and rewrite pattern rules for operations // from the Arithmetic dialect. The basic premise is that @@ -127,7 +128,7 @@ void populateArithmeticPatternsAndLegality( // Cast Ops ArithGenericPattern, ArithGenericPattern - >(context); + >(typeConverter, context); } // @@ -234,7 +235,9 @@ struct TritonReducePattern : public OpConversionPattern { } }; -void populateTritonPatterns(RewritePatternSet &patterns) { +void populateTritonPatterns( + TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns +) { MLIRContext *context = patterns.getContext(); patterns.add, TritonGenericPattern, @@ -244,7 +247,7 @@ void populateTritonPatterns(RewritePatternSet &patterns) { TritonDotPattern, TritonLoadPattern, TritonStorePattern - >(context); + >(typeConverter, context); } // @@ -307,9 +310,12 @@ struct SCFYieldPattern : public OpConversionPattern { } }; -void populateSCFPatterns(RewritePatternSet &patterns) { +void populateSCFPatterns( + TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns +) { MLIRContext *context = patterns.getContext(); - patterns.add(context); + patterns.add(typeConverter, context); } @@ -324,16 +330,17 @@ public: MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); int numThreads = numWarps * 32; - - TritonGPUConversionTarget target(*context); + // type converter + TritonGPUTypeConverter typeConverter(context, numThreads); + TritonGPUConversionTarget target(*context, typeConverter); // rewrite patterns RewritePatternSet patterns(context); // add rules - populateArithmeticPatternsAndLegality(patterns, target); - populateTritonPatterns(patterns); + populateArithmeticPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? - populateSCFPatterns(patterns); + populateSCFPatterns(typeConverter, patterns); if(failed(applyPartialConversion(mod, target, std::move(patterns)))) diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index dddbe4de9..970ecbedc 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -90,8 +90,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // // TritonGPUConversion // -TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context) - : ConversionTarget(context) { +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context), typeConverter(typeConverter) { // TODO: we should also verify ops of TritonGPUDialect addLegalDialect(); @@ -103,18 +104,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context) triton::TritonDialect, StandardOpsDialect, scf::SCFDialect>([&](Operation *op) { - auto isLegal = [](Value v) -> bool { - Type type = v.getType(); - if (auto tensorType = type.dyn_cast()) { - if (tensorType.getEncoding()) - return true; - return false; - } - return true; - }; - - if (llvm::all_of(op->getOperands(), isLegal) && - llvm::all_of(op->getResults(), isLegal)) + if (typeConverter.isLegal(op)) return true; return false; }); @@ -127,6 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context) if (aEncoding && aEncoding.isa() && bEncoding && bEncoding.isa()) return true; + // // TODO: we should delete this + // if (this->typeConverter.isLegal(dotOp)) + // return true; return false; });