From 64d0b87ef06fc689629220c8fb0522cddcb271d0 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 18 Jun 2022 14:34:59 +0800 Subject: [PATCH] Remove TypeConverter from TritonToTritonGPU conversion --- .../Transforms/TritonGPUConversion.h | 3 +- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 29 +++++++------------ .../Transforms/TritonGPUConversion.cpp | 21 +++++++++----- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index 2f34d71f7..fddcf2905 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -20,9 +20,8 @@ private: }; class TritonGPUConversionTarget : public ConversionTarget { - TritonGPUTypeConverter &typeConverter; public: - explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); + explicit TritonGPUConversionTarget(MLIRContext &ctx); }; } // namespace mlir diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 591af75b5..81ef9e790 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -73,8 +73,7 @@ public: }; void populateArithmeticPatternsAndLegality( - TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns, - TritonGPUConversionTarget &target){ + RewritePatternSet &patterns, TritonGPUConversionTarget &target){ // -------------- // Add legality and rewrite pattern rules for operations // from the Arithmetic dialect. The basic premise is that @@ -128,7 +127,7 @@ void populateArithmeticPatternsAndLegality( // Cast Ops ArithGenericPattern, ArithGenericPattern - >(typeConverter, context); + >(context); } // @@ -235,9 +234,7 @@ struct TritonReducePattern : public OpConversionPattern { } }; -void populateTritonPatterns( - TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns -) { +void populateTritonPatterns(RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add, TritonGenericPattern, @@ -247,7 +244,7 @@ void populateTritonPatterns( TritonDotPattern, TritonLoadPattern, TritonStorePattern - >(typeConverter, context); + >(context); } // @@ -310,12 +307,9 @@ struct SCFYieldPattern : public OpConversionPattern { } }; -void populateSCFPatterns( - TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns -) { +void populateSCFPatterns(RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(typeConverter, context); + patterns.add(context); } @@ -330,17 +324,16 @@ public: MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); int numThreads = numWarps * 32; - // type converter - TritonGPUTypeConverter typeConverter(context, numThreads); - TritonGPUConversionTarget target(*context, typeConverter); + + TritonGPUConversionTarget target(*context); // rewrite patterns RewritePatternSet patterns(context); // add rules - populateArithmeticPatternsAndLegality(typeConverter, patterns, target); - populateTritonPatterns(typeConverter, patterns); + populateArithmeticPatternsAndLegality(patterns, target); + populateTritonPatterns(patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? - populateSCFPatterns(typeConverter, patterns); + populateSCFPatterns(patterns); if(failed(applyPartialConversion(mod, target, std::move(patterns)))) diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 970ecbedc..dddbe4de9 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -90,9 +90,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // // TritonGPUConversion // -TritonGPUConversionTarget::TritonGPUConversionTarget( - MLIRContext &context, TritonGPUTypeConverter &typeConverter) - : ConversionTarget(context), typeConverter(typeConverter) { +TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context) + : ConversionTarget(context) { // TODO: we should also verify ops of TritonGPUDialect addLegalDialect(); @@ -104,7 +103,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( triton::TritonDialect, StandardOpsDialect, scf::SCFDialect>([&](Operation *op) { - if (typeConverter.isLegal(op)) + auto isLegal = [](Value v) -> bool { + Type type = v.getType(); + if (auto tensorType = type.dyn_cast()) { + if (tensorType.getEncoding()) + return true; + return false; + } + return true; + }; + + if (llvm::all_of(op->getOperands(), isLegal) && + llvm::all_of(op->getResults(), isLegal)) return true; return false; }); @@ -117,9 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( if (aEncoding && aEncoding.isa() && bEncoding && bEncoding.isa()) return true; - // // TODO: we should delete this - // if (this->typeConverter.isLegal(dotOp)) - // return true; return false; });