#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "../PassDetail.h" using namespace mlir; using namespace mlir::triton; namespace { class ConvertArithmeticOp: public ConversionPattern { public: ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const override { Dialect* dialect = op->getDialect(); if(dialect->getTypeID() != mlir::TypeID::get()) return failure(); // Arithmetic op to legalize here. Create layout conversion if necessary return success(); } }; void populateArithmeticPatternsAndLegality( TypeConverter& typeConverter, RewritePatternSet &patterns, ConversionTarget &target){ // -------------- // Add legality and rewrite pattern rules for operations // from the Arithmetic dialect. The basic premise is that // arithmetic operations require both inputs to have the same // non-null encoding // -------------- MLIRContext *context = patterns.getContext(); // Legality rule target.addDynamicallyLegalDialect( // TODO: check above rule here [](Operation *op){ return true; } ); // Rewrite rule patterns.add(typeConverter, context); } class ConvertTritonToTritonGPU : public ConvertTritonToTritonGPUBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); TritonGPUConversionTarget target(*context); ModuleOp mod = getOperation(); // int numThreads = mod.getAttr(); // type converter TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32); // rewrite patterns RewritePatternSet patterns(context); // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); if(failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass() { return std::make_unique<::ConvertTritonToTritonGPU>(); }