From d9017f85936c8ad98c0f9ccec6d11c3041a80be7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 30 Apr 2022 20:42:25 -0700 Subject: [PATCH] add basic template for legalizing arithmetic op --- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index a0e93f48c..8b8ceb289 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -1,5 +1,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "../PassDetail.h" @@ -8,6 +9,44 @@ using namespace mlir::triton; namespace { +class ConvertArithmeticOp: public ConversionPattern { +public: + ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, + context) {} + + LogicalResult matchAndRewrite(Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + Dialect* dialect = op->getDialect(); + if(dialect->getTypeID() != mlir::TypeID::get()) + return failure(); + // Arithmetic op to legalize here. Create layout conversion if necessary + return success(); + } +}; + +void populateArithmeticPatternsAndLegality( + TypeConverter& typeConverter, RewritePatternSet &patterns, + ConversionTarget &target){ + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arithmetic dialect. The basic premise is that + // arithmetic operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // Legality rule + target.addDynamicallyLegalDialect( + // TODO: check above rule here + [](Operation *op){ + return false; + } + ); + // Rewrite rule + patterns.add(typeConverter, context); +} + + class ConvertTritonToTritonGPU: public ConvertTritonToTritonGPUBase { @@ -24,7 +63,18 @@ public: void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); - std::cout << "Converting" << std::endl; + // type converter + TypeConverter typeConverter; + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithmeticPatternsAndLegality(typeConverter, patterns, target); + + + if(failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } };