#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "llvm/ADT/APSInt.h"
#include <numeric> // std::iota

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"

namespace {

/// Rewrites an op into the *same* op, but with the result type run through the
/// TritonGPU type converter so the tensor type carries a layout encoding.
/// Works for any op whose builder accepts (resultType, operands).
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return success();
  }
};

/// Rewrites an arith comparison op (SrcOp) into its TritonGPU counterpart
/// (DstOp), converting the result type and forwarding predicate + operands.
template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                       adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};

/// Rewrites arith.constant so that a dense tensor constant gets the converted
/// (encoded) result type.
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
    assert(value);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, retType,
        value.reshape(retType) // This is a hack. We just want to add encoding
    );
    return success();
  }
};

/// Matches any op from the Arithmetic dialect.
/// NOTE(review): this pattern performs no rewrite — it succeeds without
/// touching the matched op, so it appears to be dead/unused; confirm before
/// relying on it.
class ConvertArithmeticOp : public ConversionPattern {
public:
  ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
                      MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Dialect *dialect = op->getDialect();
    if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
      return failure();
    return success();
  }
};

void populateArithmeticPatternsAndLegality(
    TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
    TritonGPUConversionTarget &target) {
  // --------------
  // Add legality and rewrite pattern rules for operations
  // from the Arithmetic dialect. The basic premise is that
  // arithmetic operations require both inputs to have the same
  // non-null encoding
  // --------------
  MLIRContext *context = patterns.getContext();
  // TODO: there's probably a better way to avoid adding all ops one-by-one
  patterns.add<
      ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
      GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
      GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
      GenericOpPattern<arith::CeilDivUIOp>,
      GenericOpPattern<arith::CeilDivSIOp>,
      GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
      GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
      GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
      GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
      GenericOpPattern<arith::ShRSIOp>, // NegFOp
      // Floating point
      GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
      // MaxMin
      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
      GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
      // Floating point
      GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
      GenericOpPattern<arith::RemFOp>,
      // Cmp
      ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
      ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
      // Cast Ops
      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
      GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
      GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
      GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
      GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}

// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
        adaptor.getFalseValue());
    return success();
  }
};

void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
                                    TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<StdSelectPattern>(typeConverter, context);
  target.addLegalOp<ReturnOp>(); // this is ok because all functions are
                                 // inlined by the frontend
}

void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
               GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
               GenericOpPattern<math::SqrtOp>>(typeConverter, context);
}

//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( op, retType, adaptor.start(), adaptor.end()); return success(); } }; struct TritonExpandDimsPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Type retType = op.getType()); RankedTensorType argType = adaptor.src().getType().cast(); Attribute _argEncoding = argType.getEncoding(); if (!_argEncoding) return failure(); auto argEncoding = _argEncoding.cast(); // return shape auto retShape = argType.getShape().vec(); retShape.insert(retShape.begin() + op.axis(), 1); // return encoding auto retSizePerThread = argEncoding.getSizePerThread().vec(); retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1); auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec(); retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1); auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec(); retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1); SmallVector retOrder(retShape.size()); std::iota(retOrder.begin(), retOrder.end(), 0); triton::gpu::BlockedEncodingAttr retEncoding = triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, retOrder); // convert operand to slice of return type Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( getContext(), op.axis(), retEncoding); RankedTensorType newArgType = RankedTensorType::get( argType.getShape(), argType.getElementType(), newArgEncoding); // construct new op auto newSrc = rewriter.create( op.getLoc(), newArgType, adaptor.src()); rewriter.replaceOpWithNewOp(op, 
newSrc, adaptor.axis()); return success(); } }; struct TritonDotPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType origType = op.getType().cast(); auto origShape = origType.getShape(); auto typeConverter = getTypeConverter(); int numWarps = typeConverter->getNumWarps(); SmallVector retSizePerThread = {1, 1}; if (origShape[0] * origShape[1] / (numWarps * 32) >= 4) retSizePerThread = {2, 2}; if (origShape[0] * origShape[1] / (numWarps * 32) >= 16) retSizePerThread = {4, 4}; SmallVector retOrder = {1, 0}; Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( getContext(), origShape, retSizePerThread, retOrder, numWarps); RankedTensorType retType = RankedTensorType::get(origShape, origType.getElementType(), dEncoding); // a & b must be of smem layout auto aType = adaptor.a().getType().cast(); auto bType = adaptor.b().getType().cast(); Attribute aEncoding = aType.getEncoding(); Attribute bEncoding = bType.getEncoding(); if (!aEncoding || !bEncoding) return failure(); Value a = adaptor.a(); Value b = adaptor.b(); Value c = adaptor.c(); if (!aEncoding.isa()) { Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding); auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); a = rewriter.create(a.getLoc(), dstType, a); } if (!bEncoding.isa()) { Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding); auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); } c = rewriter.create(c.getLoc(), retType, c); rewriter.replaceOpWithNewOp(op, retType, a, b, c, adaptor.allowTF32()); return success(); } }; struct TritonCatPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult 
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // For now, this behaves like generic, but this will evolve when // we add support for `can_reorder=False` Type retType = this->getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; struct TritonTransPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value src = adaptor.src(); auto srcType = src.getType().cast(); Attribute srcEncoding = srcType.getEncoding(); if (!srcEncoding) return failure(); if (!srcEncoding.isa()) { // TODO: end-to-end correctness is broken if // the input is blocked and the output is shared // with different order. Maybe a backend issue in BlockedToShared? SmallVector order = {1, 0}; if (auto srcBlockedEncoding = srcEncoding.dyn_cast()) llvm::copy(srcBlockedEncoding.getOrder(), order.begin()); srcEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); srcType = RankedTensorType::get(srcType.getShape(), srcType.getElementType(), srcEncoding); src = rewriter.create(src.getLoc(), srcType, src); } auto srcSharedEncoding = srcEncoding.cast(); SmallVector retOrder(srcSharedEncoding.getOrder().begin(), srcSharedEncoding.getOrder().end()); SmallVector retShapes(srcType.getShape().begin(), srcType.getShape().end()); std::reverse(retOrder.begin(), retOrder.end()); std::reverse(retShapes.begin(), retShapes.end()); auto retEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder); auto retType = RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding); rewriter.replaceOpWithNewOp(op, retType, src); return success(); } }; struct TritonLoadPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult 
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), adaptor.ptr(), adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(), adaptor.isVolatile()); return success(); } }; struct TritonStorePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, adaptor.ptr(), adaptor.value(), adaptor.mask()); return success(); } }; struct TritonAtomicCASPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), adaptor.ptr(), adaptor.cmp(), adaptor.val()); return success(); } }; struct TritonAtomicRMWPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(), adaptor.ptr(), adaptor.val(), adaptor.mask()); return success(); } }; struct TritonExtElemwisePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), adaptor.args(), adaptor.libname(), adaptor.libpath(), adaptor.symbol()); return success(); } }; template struct TritonGenericPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, 
typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; struct TritonBroadcastPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // This creates a tensor with the new shape but the argument's layout LogicalResult matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = adaptor.src().getType().cast(); auto srcEncoding = srcType.getEncoding(); if (!srcEncoding) return failure(); auto opType = op.getType().cast(); Type retType = RankedTensorType::get(opType.getShape(), opType.getElementType(), srcEncoding); // Type retType = this->getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); return success(); } }; struct TritonReducePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, adaptor.redOp(), adaptor.operand(), adaptor.axis()); return success(); } }; struct TritonPrintfPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.prefixAttr(), adaptor.getOperands()); return success(); } }; void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add< // TODO: view should have custom pattern that views the layout TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, 
TritonBroadcastPattern, TritonGenericPattern, TritonCatPattern, TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context); } // // SCF patterns // // This is borrowed from ConvertForOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp struct SCFForPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // Ref: ConvertForOpTypes LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), newOp.getLoopBody().end()); // Now, update all the types. // Convert the types of block arguments within the given region. This // replaces each block with a new block containing the updated signature. // The entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), *getTypeConverter()))) { return rewriter.notifyMatchFailure(op, "could not convert body types"); } // Change the clone to use the updated operands. We could have cloned with // a BlockAndValueMapping, but this seems a bit more direct. newOp->setOperands(adaptor.getOperands()); // Update the result types to the new converted types. 
SmallVector newResultTypes; for (Type type : op.getResultTypes()) { Type newType = typeConverter->convertType(type); if (!newType) return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); newResultTypes.push_back(newType); } for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); rewriter.replaceOp(op, newOp.getResults()); return success(); } }; struct SCFYieldPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); // rewriter.create(op.getLoc(), adaptor.getOperands()); // op.erase(); rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; // This is borrowed from ConvertFIfOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp class SCFIfPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO: Generalize this to any type conversion, not just 1:1. // // We need to implement something more sophisticated here that tracks which // types convert to which other types and does the appropriate // materialization logic. // For example, it's possible that one result type converts to 0 types and // another to 2 types, so newResultTypes would at least be the right size to // not crash in the llvm::zip call below, but then we would set the the // wrong type on the SSA values! These edge cases are also why we cannot // safely use the TypeConverter::convertTypes helper here. 
SmallVector newResultTypes; for (auto type : op.getResultTypes()) { Type newType = typeConverter->convertType(type); if (!newType) return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); newResultTypes.push_back(newType); } // See comments in the ForOp pattern for why we clone without regions and // then inline. scf::IfOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), newOp.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), newOp.getElseRegion().end()); // Update the operands and types. newOp->setOperands(adaptor.getOperands()); for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); rewriter.replaceOp(op, newOp.getResults()); return success(); } }; void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add(typeConverter, context); } class ConvertTritonToTritonGPU : public ConvertTritonToTritonGPUBase { public: ConvertTritonToTritonGPU() = default; // constructor with some parameters set explicitly. ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); // type converter TritonGPUTypeConverter typeConverter(context, numWarps); TritonGPUConversionTarget target(*context, typeConverter); // rewrite patterns RewritePatternSet patterns(context); // add rules populateStdPatternsAndLegality(typeConverter, patterns, target); populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateMathPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? 
populateSCFPatterns(typeConverter, patterns); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); auto inti = llvm::APSInt(32, false); auto i32_ty = IntegerType::get(mod->getContext(), 32); mod->setAttr( AttrNumWarpsName, IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue()))); // update layouts // broadcast src => multicast, dst => broadcasted // if (failed(target.refineLayouts(mod, numWarps))) // return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) { return std::make_unique<::ConvertTritonToTritonGPU>(numWarps); } std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass() { return std::make_unique<::ConvertTritonToTritonGPU>(); }