#include "ViewOpToLLVM.h"
#include "DotOpHelpers.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;

// Lowers triton::SplatOp (and, via ArithConstantSplatOpConversion below,
// arith::ConstantOp with a SplatElementsAttr) to an LLVM struct whose fields
// all hold the same scalar value — one field per element owned by the thread.
struct SplatOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;

  // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
  // LLVM::StructType value.
  //
  // @elemType: the element type in operand.
  // @resType: the return type of the Splat-like op.
  // @constVal: a LLVM::ConstantOp or other scalar value.
  static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                                  TypeConverter *typeConverter,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) {
    auto tensorTy = resType.cast<RankedTensorType>();
    // Blocked/slice layouts: replicate the (bitcast) scalar once per element
    // owned by the thread and pack the copies into a literal struct.
    if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
        tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
      auto srcType = typeConverter->convertType(elemType);
      auto llSrc = bitcast(constVal, srcType);
      size_t elemsPerThread = getElemsPerThread(tensorTy);
      llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
      llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
      auto structTy =
          LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);

      return getStructFromElements(loc, elems, rewriter, structTy);
    } else if (auto dotLayout =
                   tensorTy.getEncoding()
                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
      return convertSplatLikeOpWithDotOperandLayout(
          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
    } else if (auto mmaLayout =
                   tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
      return convertSplatLikeOpWithMmaLayout(
          mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
    } else
      assert(false && "Unsupported layout found in ConvertSplatLikeOp");
    return {};
  }

  // Splat for a tensor carrying a DotOperand layout. The per-thread element
  // count depends on the parent layout (Ampere/Volta MMA or blocked FMA) and
  // on which operand (A = opIdx 0, B = opIdx 1) this tensor feeds.
  static Value convertSplatLikeOpWithDotOperandLayout(
      const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
      Type elemType, Value constVal, TypeConverter *typeConverter,
      ConversionPatternRewriter &rewriter, Location loc) {
    auto tensorTy = resType.cast<RankedTensorType>();
    auto shape = tensorTy.getShape();
    auto parent = layout.getParent();
    int numElems{};
    if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
      if (mmaLayout.isAmpere()) {
        numElems = layout.getOpIdx() == 0
                       ? MMA16816ConversionHelper::getANumElemsPerThread(
                             tensorTy, mmaLayout.getWarpsPerCTA()[0])
                       : MMA16816ConversionHelper::getBNumElemsPerThread(
                             tensorTy, mmaLayout.getWarpsPerCTA()[1]);
      } else if (mmaLayout.isVolta()) {
        DotOpMmaV1ConversionHelper helper(mmaLayout);
        numElems = layout.getOpIdx() == 0
                       ? helper.numElemsPerThreadA(shape, {0, 1})
                       : helper.numElemsPerThreadB(shape, {0, 1});
      }
    } else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
      numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
    } else {
      assert(false && "Unsupported layout found");
    }
    auto structTy = LLVM::LLVMStructType::getLiteral(
        rewriter.getContext(), SmallVector<Type>(numElems, elemType));
    return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
                                 rewriter, structTy);
  }

  // Splat for a tensor carrying an MMA layout (the accumulator/C operand).
  static Value convertSplatLikeOpWithMmaLayout(
      const MmaEncodingAttr &layout, Type resType, Type elemType,
      Value constVal, TypeConverter *typeConverter,
      ConversionPatternRewriter &rewriter, Location loc) {
    auto tensorTy = resType.cast<RankedTensorType>();
    auto shape = tensorTy.getShape();
    if (layout.isAmpere()) {
      // mma.16816 accumulators hold 4 values per (repM, repN) tile per thread.
      auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
      size_t fcSize = 4 * repM * repN;
      auto structTy = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
      return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
                                   rewriter, structTy);
    }
    if (layout.isVolta()) {
      DotOpMmaV1ConversionHelper helper(layout);
      int repM = helper.getRepM(shape[0]);
      int repN = helper.getRepN(shape[1]);
      // According to mma layout of v1, each thread processes 8 elements.
      int elems = 8 * repM * repN;
      auto structTy = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), SmallVector<Type>(elems, elemType));
      return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
                                   rewriter, structTy);
    }

    assert(false && "Unsupported mma layout found");
    return {};
  }

  LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    auto loc = op->getLoc();
    auto src = adaptor.src();
    auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
                                       getTypeConverter(), rewriter, loc);
    rewriter.replaceOp(op, {llStruct});
    return success();
  }
};

// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
// the logic is the same as triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto value = op.getValue();
    // Only splat constants are handled here; others fall through to other
    // patterns.
    if (!value.dyn_cast<SplatElementsAttr>())
      return failure();

    auto loc = op->getLoc();

    LLVM::ConstantOp arithConstantOp;
    auto values = op.getValue().dyn_cast<SplatElementsAttr>();
    auto elemType = values.getElementType();

    // Extract the single splat scalar as a float or integer attribute.
    Attribute val;
    if (elemType.isBF16() || type::isFloat(elemType)) {
      val = values.getValues<FloatAttr>()[0];
    } else if (type::isInt(elemType)) {
      val = values.getValues<IntegerAttr>()[0];
    } else {
      llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
                   << value.getType() << "\n";
      return failure();
    }

    // Materialize the scalar as an LLVM constant, then splat it.
    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
    auto llStruct = SplatOpConversion::convertSplatLikeOp(
        elemType, op.getType(), constOp, getTypeConverter(), rewriter, loc);
    rewriter.replaceOp(op, llStruct);

    return success();
  }
};

// Lowers triton::CatOp by unpacking both operand structs and re-packing the
// concatenated value list into the result struct.
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
  using OpAdaptor = typename CatOp::Adaptor;

  explicit CatOpConversion(LLVMTypeConverter &typeConverter,
                           PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = op.getType().template cast<RankedTensorType>();
    unsigned elems = getElemsPerThread(resultTy);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);
    // unpack input values
    auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
    auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
    // concatenate (and potentially reorder) values
    SmallVector<Value> retVals;
    for (Value v : lhsVals)
      retVals.push_back(v);
    for (Value v : rhsVals)
      retVals.push_back(v);
    // pack and replace
    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
    Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
    rewriter.replaceOp(op, ret);
    return success();
  }
};

// Lowers shape-only ops (ViewOp, ExpandDimsOp): the per-thread values are
// unchanged, so the operand struct is simply unpacked and re-packed with the
// result's struct type.
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
                                PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We cannot directly run `rewriter.replaceOp(op, adaptor.src())`
    // due to MLIR's restrictions
    Location loc = op->getLoc();
    auto resultTy = op.getType().template cast<RankedTensorType>();
    unsigned elems = getElemsPerThread(resultTy);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);
    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
    auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
    Value view = getStructFromElements(loc, vals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
  }
};

// Lowers triton::TransOp on a shared-memory-resident tensor: no data is
// moved — the 2D strides and offsets of the shared memory descriptor are
// simply swapped.
struct TransOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto srcSmemObj =
        getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
    SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
                                     srcSmemObj.strides[0]};
    SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
                                     srcSmemObj.offsets[0]};
    auto dstSmemObj =
        SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

// Registers all view-like lowering patterns. numWarps, axisInfoAnalysis,
// allocation and smem are unused here but kept for a uniform populate-function
// signature across the conversion passes.
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
                                  AxisInfoAnalysis &axisInfoAnalysis,
                                  const Allocation *allocation, Value smem,
                                  PatternBenefit benefit) {
  patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
  patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
                                                           benefit);
  patterns.add<SplatOpConversion>(typeConverter, benefit);
  patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
  patterns.add<CatOpConversion>(typeConverter, benefit);
  patterns.add<TransOpConversion>(typeConverter, benefit);
}