#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
// NOTE(review): the header name of this include was lost in extraction;
// <numeric> is the most plausible candidate — confirm against upstream.
#include <numeric>

using namespace mlir;
using namespace mlir::triton;

namespace {

// Converts an elementwise arith op to the same op whose result type carries
// a TritonGPU encoding. Operands were already converted by the framework, so
// only the result type needs to be recomputed through the type converter.
template <class Op>
class ArithGenericPattern : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    Op res = rewriter.replaceOpWithNewOp<Op>(
      op, retType, adaptor.getOperands()
    );
    return success();
  }
};

// Converts a comparison op (SrcOp) into its DstOp counterpart, forwarding the
// predicate and both operands while re-encoding the result type.
template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
      op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
    );
    return success();
  }
};

// Converts arith.constant: rebuilds the dense attribute with the converted
// (encoded) tensor type so the constant's value matches its new result type.
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
    assert(value);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
      op, retType,
      value.reshape(retType) // This is a hack. We just want to add encoding
    );
    return success();
  }
};

// Catch-all pattern for the Arithmetic dialect. Currently it only filters by
// dialect and rewrites nothing (it is not registered below — see the
// commented-out patterns.add call in populateArithmeticPatternsAndLegality).
class ConvertArithmeticOp : public ConversionPattern {
public:
  ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const override {
    Dialect *dialect = op->getDialect();
    if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
      return failure();
    return success();
  }
};

void populateArithmeticPatternsAndLegality(
    TritonGPUTypeConverter &typeConverter,
    RewritePatternSet &patterns,
    TritonGPUConversionTarget &target) {
  // --------------
  // Add legality and rewrite pattern rules for operations
  // from the Arithmetic dialect. The basic premise is that
  // arithmetic operations require both inputs to have the same
  // non-null encoding
  // --------------
  MLIRContext *context = patterns.getContext();
  // // Legality rule
  // target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
  //   // TODO: check above rule here
  //   [](Operation *op){
  //     return true;
  //   }
  // );
  // Rewrite rule
  // patterns.add<ConvertArithmeticOp>(typeConverter, context);
  // NOTE(review): the template arguments below were reconstructed from the
  // pattern/comment layout (16 integer ops, 2+3 FP ops, 6 max/min, 2 cmp,
  // 2 casts) — confirm the exact op list against upstream.
  patterns.add<ArithConstantPattern,
               ArithGenericPattern<arith::AddIOp>,
               ArithGenericPattern<arith::SubIOp>,
               ArithGenericPattern<arith::MulIOp>,
               ArithGenericPattern<arith::DivUIOp>,
               ArithGenericPattern<arith::DivSIOp>,
               ArithGenericPattern<arith::CeilDivUIOp>,
               ArithGenericPattern<arith::CeilDivSIOp>,
               ArithGenericPattern<arith::FloorDivSIOp>,
               ArithGenericPattern<arith::RemUIOp>,
               ArithGenericPattern<arith::RemSIOp>,
               ArithGenericPattern<arith::AndIOp>,
               ArithGenericPattern<arith::OrIOp>,
               ArithGenericPattern<arith::XOrIOp>,
               ArithGenericPattern<arith::ShLIOp>,
               ArithGenericPattern<arith::ShRUIOp>,
               ArithGenericPattern<arith::ShRSIOp>, // NegFOp
               // Floating point
               ArithGenericPattern<arith::AddFOp>,
               ArithGenericPattern<arith::SubFOp>,
               // MaxMin
               ArithGenericPattern<arith::MaxFOp>,
               ArithGenericPattern<arith::MaxSIOp>,
               ArithGenericPattern<arith::MaxUIOp>,
               ArithGenericPattern<arith::MinFOp>,
               ArithGenericPattern<arith::MinSIOp>,
               ArithGenericPattern<arith::MinUIOp>,
               // Floating point
               ArithGenericPattern<arith::MulFOp>,
               ArithGenericPattern<arith::DivFOp>,
               ArithGenericPattern<arith::RemFOp>,
               // Cmp
               ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
               ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
               // Cast Ops
               ArithGenericPattern<arith::TruncIOp>,
               ArithGenericPattern<arith::TruncFOp>
              >(typeConverter, context);
}

//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, retType, adaptor.start(), adaptor.end() ); return success(); } }; struct TritonDotPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); // a & b must be of smem layout auto aType = adaptor.a().getType().cast(); auto bType = adaptor.b().getType().cast(); Attribute aEncoding = aType.getEncoding(); Attribute bEncoding = bType.getEncoding(); if (!aEncoding || !bEncoding) return failure(); Value a = adaptor.a(); Value b = adaptor.b(); if (!aEncoding.isa()) { Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); a = rewriter.create(a.getLoc(), dstType, a); } if (!bEncoding.isa()) { Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1); auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); } auto newDot = rewriter.replaceOpWithNewOp( op, retType, a, b, adaptor.c(), adaptor.allowTF32() ); // auto newDot = rewriter.create(op.getLoc(), retType, // a, b, adaptor.c(), adaptor.allowTF32()); // rewriter.replaceOp(op, {newDot}); return success(); } }; struct TritonLoadPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = 
getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(), adaptor.isVolatile() ); return success(); } }; struct TritonStorePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newOp = rewriter.replaceOpWithNewOp( op, adaptor.ptr(), adaptor.value(), adaptor.mask() ); return success(); } }; template struct TritonGenericPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, retType, adaptor.getOperands() ); return success(); } }; void populateTritonPatterns( TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns ) { MLIRContext *context = patterns.getContext(); patterns.add, TritonGenericPattern, TritonGenericPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern >(typeConverter, context); } // // SCF patterns // // This is borrowed from ConvertForOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp struct SCFForPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // Ref: ConvertForOpTypes LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), newOp.getLoopBody().end()); // Now, update all the types. // Convert the types of block arguments within the given region. This // replaces each block with a new block containing the updated signature. 
The // entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), *getTypeConverter()))) { return rewriter.notifyMatchFailure(op, "could not convert body types"); } // Change the clone to use the updated operands. We could have cloned with // a BlockAndValueMapping, but this seems a bit more direct. newOp->setOperands(adaptor.getOperands()); // Update the result types to the new converted types. SmallVector newResultTypes; for (Type type : op.getResultTypes()) { Type newType = typeConverter->convertType(type); if (!newType) return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); newResultTypes.push_back(newType); } for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); rewriter.replaceOp(op, newOp.getResults()); return success(); } }; struct SCFYieldPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); // rewriter.create(op.getLoc(), adaptor.getOperands()); // op.erase(); rewriter.replaceOpWithNewOp( op, adaptor.getOperands() ); return success(); } }; void populateSCFPatterns( TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns ) { MLIRContext *context = patterns.getContext(); patterns.add(typeConverter, context); } class ConvertTritonToTritonGPU : public ConvertTritonToTritonGPUBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); // int numThreads = mod.getAttr(); // type converter TritonGPUTypeConverter typeConverter(context, /*numThreads*/32); TritonGPUConversionTarget target(*context, typeConverter); // rewrite patterns 
RewritePatternSet patterns(context); // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); if(failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); } }; } std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass() { return std::make_unique<::ConvertTritonToTritonGPU>(); }