#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H

// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>)
#include "triton/Analysis/Allocation.h"
//
#include "Utility.h"

#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in header files in mlir v14.
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
private:
  /// Only retain those attributes that are not constructed by
  /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
  /// attributes.
  static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
                                   bool filterArgAttrs,
                                   SmallVectorImpl<NamedAttribute> &result) {
    for (const auto &attr : attrs) {
      if (attr.getName() == SymbolTable::getSymbolAttrName() ||
          attr.getName() == FunctionOpInterface::getTypeAttrName() ||
          attr.getName() == "std.varargs" ||
          (filterArgAttrs &&
           attr.getName() == FunctionOpInterface::getArgDictAttrName()))
        continue;
      result.push_back(attr);
    }
  }

  /// Helper function for wrapping all attributes into a single DictionaryAttr
  static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
    return DictionaryAttr::get(b.getContext(),
                               b.getNamedAttr("llvm.struct_attrs", attrs));
  }

protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
  // to this legalization pattern.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    if (!llvmType)
      return nullptr;

    // Propagate argument/result attributes to all converted arguments/result
    // obtained after converting a given original argument/result.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
      assert(!resAttrDicts.empty() && "expected array to be non-empty");
      auto newResAttrDicts =
          (funcOp.getNumResults() == 1)
              ? resAttrDicts
              : rewriter.getArrayAttr(
                    {wrapAsStructAttrs(rewriter, resAttrDicts)});
      attributes.push_back(rewriter.getNamedAttr(
          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
    }
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        auto mapping = result.getInputMapping(i);
        assert(mapping && "unexpected deletion of function argument");
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }
    for (const auto &pair : llvm::enumerate(attributes)) {
      if (pair.value().getName() == "llvm.linkage") {
        attributes.erase(attributes.begin() + pair.index());
        break;
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (funcOp->hasAttr("llvm.linkage")) {
      auto attr =
          funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
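
// Key type for the index caches below: a (layout attribute, tensor shape)
// pair, so IR emitted for one (layout, shape) query can be reused by later
// queries with the same key. CacheKeyDenseMapInfo supplies the sentinel
// empty/tombstone keys that DenseMap requires, built from
// llvm::DenseMapInfo<void *>'s sentinel pointers.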
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;

struct CacheKeyDenseMapInfo {
  static IndexCacheKeyT getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return std::make_pair(
        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
        SmallVector<int64_t>{});
  }
  static IndexCacheKeyT getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return std::make_pair(
        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
        SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
  }
  static unsigned getHashValue(IndexCacheKeyT key) {
    return llvm::hash_combine(
        mlir::hash_value(key.first),
        llvm::hash_combine_range(key.second.begin(), key.second.end()));
  }
  static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
    return LHS == RHS;
  }
};

class ConvertTritonGPUOpToLLVMPatternBase {
public:
  // Two levels of value cache in emitting indices calculation:
  // Key: pair<layout, shape>
  struct IndexCacheInfo {
    DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
        *baseIndexCache;
    DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
             CacheKeyDenseMapInfo> *indexCache;
    OpBuilder::InsertPoint *indexInsertPoint;
  };

  explicit ConvertTritonGPUOpToLLVMPatternBase(
      LLVMTypeConverter &typeConverter)
      : converter(&typeConverter) {}

  explicit ConvertTritonGPUOpToLLVMPatternBase(
      LLVMTypeConverter &typeConverter, const Allocation *allocation,
      Value smem)
      : converter(&typeConverter), allocation(allocation), smem(smem) {}

  explicit ConvertTritonGPUOpToLLVMPatternBase(
      LLVMTypeConverter &typeConverter, const Allocation *allocation,
      Value smem, IndexCacheInfo indexCacheInfo)
      : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
        allocation(allocation), smem(smem) {}

  LLVMTypeConverter *getTypeConverter() const { return converter; }

  static Value
  getStructFromSharedMemoryObject(Location loc,
                                  const SharedMemoryObject &smemObj,
                                  ConversionPatternRewriter &rewriter) {
    auto elems = smemObj.getElems();
    auto types = smemObj.getTypes();
    auto structTy =
        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
    return getStructFromElements(loc, elems, rewriter, structTy);
  }
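
  // Returns the thread id along the x dimension as a Value of the converted
  // index type. The gpu dialect's ThreadIdOp yields an `index` result, so an
  // UnrealizedConversionCastOp bridges it into the LLVM type system until the
  // cast is resolved by later lowering.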
  Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
    auto cast = rewriter.create<UnrealizedConversionCastOp>(
        loc, TypeRange{llvmIndexTy},
        ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
    Value threadId = cast.getResult(0);
    return threadId;
  }

  // -----------------------------------------------------------------------
  // Shared memory utilities
  // -----------------------------------------------------------------------
  template <typename T>
  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                            T value) const {
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
    auto bufferId = allocation->getBufferId(value);
    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
    size_t offset = allocation->getOffset(bufferId);
    Value offVal = idx_val(offset);
    Value base = gep(ptrTy, smem, offVal);
    return base;
  }

  // -----------------------------------------------------------------------
  // Utilities
  // -----------------------------------------------------------------------

  // Convert an \param index to a multi-dim coordinate given \param shape and
  // \param order.
  SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                 Location loc, Value linear,
                                 ArrayRef<unsigned> shape,
                                 ArrayRef<unsigned> order) const {
    unsigned rank = shape.size();
    assert(rank == order.size());
    auto reordered = reorder(shape, order);
    auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
    SmallVector<Value> multiDim(rank);
    for (unsigned i = 0; i < rank; ++i) {
      multiDim[order[i]] = reorderedMultiDim[i];
    }
    return multiDim;
  }

  SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                 Location loc, Value linear,
                                 ArrayRef<unsigned> shape) const {
    unsigned rank = shape.size();
    assert(rank > 0);
    SmallVector<Value> multiDim(rank);
    if (rank == 1) {
      multiDim[0] = linear;
    } else {
      Value remained = linear;
      for (auto &&en : llvm::enumerate(shape.drop_back())) {
        Value dimSize = idx_val(en.value());
        multiDim[en.index()] = urem(remained, dimSize);
        remained = udiv(remained, dimSize);
      }
      multiDim[rank - 1] = remained;
    }
    return multiDim;
  }

  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
                  ArrayRef<unsigned> order) const {
    return linearize(rewriter, loc, reorder<Value>(multiDim, order),
                     reorder<unsigned>(shape, order));
  }

  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
    auto rank = multiDim.size();
    Value linear = idx_val(0);
    if (rank > 0) {
      linear = multiDim.back();
      for (auto [dim, dimShape] :
           llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
        Value dimSize = idx_val(dimShape);
        linear = add(mul(linear, dimSize), dim);
      }
    }
    return linear;
  }

  Value dot(ConversionPatternRewriter &rewriter, Location loc,
            ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
    assert(offsets.size() == strides.size());
    Value ret = idx_val(0);
    for (auto [offset, stride] : llvm::zip(offsets, strides)) {
      ret = add(ret, mul(offset, stride));
    }
    return ret;
  }
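
  // Worked example for the utilities above: with shape = {2, 4} and
  // order = {1, 0} (dim 1 fastest-varying), linear index 6 delinearizes to
  // {6 / 4, 6 % 4} = {1, 2}, and linearize({1, 2}, {2, 4}, {1, 0}) maps it
  // back: 1 * 4 + 2 = 6. `dot` generalizes linearize to arbitrary strides,
  // emitting IR that computes sum_k(offsets[k] * strides[k]).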
  struct SmallVectorKeyInfo {
    static unsigned getHashValue(const SmallVector<unsigned> &key) {
      return llvm::hash_combine_range(key.begin(), key.end());
    }
    static bool isEqual(const SmallVector<unsigned> &lhs,
                        const SmallVector<unsigned> &rhs) {
      return lhs == rhs;
    }
    static SmallVector<unsigned> getEmptyKey() {
      return SmallVector<unsigned>();
    }
    static SmallVector<unsigned> getTombstoneKey() {
      return {std::numeric_limits<unsigned>::max()};
    }
  };

  // -----------------------------------------------------------------------
  // Get offsets / indices for any layout
  // -----------------------------------------------------------------------

  SmallVector<Value> emitBaseIndexForLayout(Location loc,
                                            ConversionPatternRewriter &rewriter,
                                            const Attribute &layout,
                                            ArrayRef<int64_t> shape) const {
    IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
    auto cache = indexCacheInfo.baseIndexCache;
    assert(cache && "baseIndexCache is nullptr");
    auto insertPt = indexCacheInfo.indexInsertPoint;
    if (cache->count(key) > 0) {
      return cache->lookup(key);
    } else {
      ConversionPatternRewriter::InsertionGuard guard(rewriter);
      restoreInsertionPointIfSet(insertPt, rewriter);
      SmallVector<Value> result;
      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
        result =
            emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
      } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
        if (mmaLayout.isVolta())
          result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
        if (mmaLayout.isAmpere())
          result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
      } else {
        llvm_unreachable("unsupported emitBaseIndexForLayout");
      }
      cache->insert(std::make_pair(key, result));
      *insertPt = rewriter.saveInsertionPoint();
      return result;
    }
  }

  SmallVector<SmallVector<unsigned>>
  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
      return emitOffsetForBlockedLayout(blockedLayout, shape);
    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
      if (mmaLayout.isVolta())
        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
      if (mmaLayout.isAmpere())
        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
    }
    llvm_unreachable("unsupported emitOffsetForLayout");
  }

  // -----------------------------------------------------------------------
  // Emit indices
  // -----------------------------------------------------------------------
  SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                              ConversionPatternRewriter &b,
                                              const Attribute &layout,
                                              ArrayRef<int64_t> shape) const {
    IndexCacheKeyT key(layout, llvm::to_vector(shape));
    auto cache = indexCacheInfo.indexCache;
    assert(cache && "indexCache is nullptr");
    auto insertPt = indexCacheInfo.indexInsertPoint;
    if (cache->count(key) > 0) {
      return cache->lookup(key);
    } else {
      ConversionPatternRewriter::InsertionGuard guard(b);
      restoreInsertionPointIfSet(insertPt, b);
      SmallVector<SmallVector<Value>> result;
      if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
        result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
      } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
        result = emitIndicesForDistributedLayout(loc, b, mma, shape);
      } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
        result = emitIndicesForSliceLayout(loc, b, slice, shape);
      } else {
        llvm_unreachable(
            "emitIndices for layouts other than blocked & slice not "
            "implemented yet");
      }
      cache->insert(std::make_pair(key, result));
      *insertPt = b.saveInsertionPoint();
      return result;
    }
  }

private:
  void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
                                  ConversionPatternRewriter &rewriter) const {
    if (insertPt->isSet()) {
      rewriter.restoreInsertionPoint(*insertPt);
    } else {
      auto func =
          rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
      rewriter.setInsertionPointToStart(&func.getBody().front());
    }
  }
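
  // Note on the caching scheme above: cached index IR is emitted at a single
  // per-function insertion point (indexInsertPoint). On the first cache miss
  // the point falls back to the function's entry block, so cached Values
  // dominate all later uses; after each emission the point is advanced past
  // the new IR via `*insertPt = saveInsertionPoint()`.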
  // -----------------------------------------------------------------------
  // Blocked layout indices
  // -----------------------------------------------------------------------

  // Get an index-base for each dimension for a \param blocked_layout.
  SmallVector<Value>
  emitBaseIndexForBlockedLayout(Location loc,
                                ConversionPatternRewriter &rewriter,
                                const BlockedEncodingAttr &blocked_layout,
                                ArrayRef<int64_t> shape) const {
    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = idx_val(32);
    Value laneId = urem(threadId, warpSize);
    Value warpId = udiv(threadId, warpSize);
    auto sizePerThread = blocked_layout.getSizePerThread();
    auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
    auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
    auto order = blocked_layout.getOrder();
    unsigned rank = shape.size();

    // delinearize threadId to get the base index
    SmallVector<Value> multiDimWarpId =
        delinearize(rewriter, loc, warpId, warpsPerCTA, order);
    SmallVector<Value> multiDimThreadId =
        delinearize(rewriter, loc, laneId, threadsPerWarp, order);

    SmallVector<Value> multiDimBase(rank);
    for (unsigned k = 0; k < rank; ++k) {
      // Wrap around multiDimWarpId/multiDimThreadId in case
      // shape[k] > shapePerCTA[k]
      auto maxWarps =
          ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
      auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
      multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
      multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
      // multiDimBase[k] = (multiDimThreadId[k] +
      //                    multiDimWarpId[k] * threadsPerWarp[k]) *
      //                   sizePerThread[k];
      Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
      Value sizePerThreadK = idx_val(sizePerThread[k]);
      multiDimBase[k] =
          mul(sizePerThreadK, add(multiDimThreadId[k],
                                  mul(multiDimWarpId[k], threadsPerWarpK)));
    }
    return multiDimBase;
  }

  SmallVector<SmallVector<unsigned>>
  emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
                             ArrayRef<int64_t> shape) const {
    auto sizePerThread = blockedLayout.getSizePerThread();
    auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
    auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
    auto order = blockedLayout.getOrder();

    unsigned rank = shape.size();
    SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
    SmallVector<unsigned> tilesPerDim(rank);
    for (unsigned k = 0; k < rank; ++k)
      tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);

    SmallVector<SmallVector<unsigned>> offset(rank);
    for (unsigned k = 0; k < rank; ++k) {
      // 1 block in minimum if shape[k] is less than shapePerCTA[k]
      for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
           ++blockOffset)
        for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
          for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
               ++threadOffset)
            for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
                 ++elemOffset)
              offset[k].push_back(blockOffset * sizePerThread[k] *
                                      threadsPerWarp[k] * warpsPerCTA[k] +
                                  warpOffset * sizePerThread[k] *
                                      threadsPerWarp[k] +
                                  threadOffset * sizePerThread[k] +
                                  elemOffset);
    }

    unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
    unsigned totalSizePerThread = product<unsigned>(sizePerThread);
    SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
    for (unsigned n = 0; n < elemsPerThread; ++n) {
      unsigned linearNanoTileId = n / totalSizePerThread;
      unsigned linearNanoTileElemId = n % totalSizePerThread;
      SmallVector<unsigned> multiDimNanoTileId =
          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
      SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
          linearNanoTileElemId, sizePerThread, order);
      for (unsigned k = 0; k < rank; ++k) {
        unsigned reorderedMultiDimId =
            multiDimNanoTileId[k] *
                (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
            multiDimNanoTileElemId[k];
        reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
      }
    }
    return reorderedOffset;
  }
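
  // Worked example: for a 1-D blocked layout with sizePerThread = {2},
  // threadsPerWarp = {32}, warpsPerCTA = {1} and shape = {128}, shapePerCTA
  // is {64} and tilesPerDim is {2}, so each thread owns 4 elements; relative
  // to its base index a thread gets offsets {0, 1, 64, 65}.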
  // -----------------------------------------------------------------------
  // Mma layout indices
  // -----------------------------------------------------------------------

  SmallVector<Value>
  emitBaseIndexForMmaLayoutV1(Location loc,
                              ConversionPatternRewriter &rewriter,
                              const MmaEncodingAttr &mmaLayout,
                              ArrayRef<int64_t> shape) const {
    llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
  }

  SmallVector<SmallVector<unsigned>>
  emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
                           ArrayRef<int64_t> shape) const {
    SmallVector<SmallVector<unsigned>> ret;
    for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
        ret.push_back({i, j});
        ret.push_back({i, j + 1});
        ret.push_back({i + 2, j});
        ret.push_back({i + 2, j + 1});
        ret.push_back({i, j + 8});
        ret.push_back({i, j + 9});
        ret.push_back({i + 2, j + 8});
        ret.push_back({i + 2, j + 9});
      }
    }
    return ret;
  }

  SmallVector<Value>
  emitBaseIndexForMmaLayoutV2(Location loc,
                              ConversionPatternRewriter &rewriter,
                              const MmaEncodingAttr &mmaLayout,
                              ArrayRef<int64_t> shape) const {
    auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
    assert(_warpsPerCTA.size() == 2);
    SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
                                      idx_val(_warpsPerCTA[1])};
    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = idx_val(32);
    Value laneId = urem(threadId, warpSize);
    Value warpId = udiv(threadId, warpSize);
    Value warpId0 = urem(warpId, warpsPerCTA[0]);
    Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
    Value offWarp0 = mul(warpId0, idx_val(16));
    Value offWarp1 = mul(warpId1, idx_val(8));

    SmallVector<Value> multiDimBase(2);
    multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
    multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
    return multiDimBase;
  }

  SmallVector<SmallVector<unsigned>>
  emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
                           ArrayRef<int64_t> shape) const {
    SmallVector<SmallVector<unsigned>> ret;
    for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
      for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
        ret.push_back({i, j});
        ret.push_back({i, j + 1});
        ret.push_back({i + 8, j});
        ret.push_back({i + 8, j + 1});
      }
    }
    return ret;
  }
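
  // For illustration: in the V2 (Ampere) base index above, lane l of a warp
  // starts at row (l / 4) and column 2 * (l % 4); the V2 offsets then add
  // {0, +1} in the column and {0, +8} in the row, which corresponds to the
  // four accumulator elements a lane holds in an m16n8 mma.sync fragment.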
  // Emit indices calculation within each ConversionPattern, and return a
  // [elemsPerThread X rank] index matrix.
  // TODO: [phil] redundant indices computations do not appear to hurt
  // performance much, but they could still significantly slow down
  // computations.
  SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
      Location loc, ConversionPatternRewriter &rewriter,
      const Attribute &layout, ArrayRef<int64_t> shape) const {
    // step 1, delinearize threadId to get the base index
    auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
    // step 2, get offset of each element
    auto offset = emitOffsetForLayout(layout, shape);
    // step 3, add offset to base, and reorder the sequence of indices to
    // guarantee that elems in the same sizePerThread are adjacent in order
    unsigned rank = shape.size();
    unsigned elemsPerThread = offset.size();
    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
                                                SmallVector<Value>(rank));
    for (unsigned n = 0; n < elemsPerThread; ++n)
      for (unsigned k = 0; k < rank; ++k)
        multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));

    return multiDimIdx;
  }

  SmallVector<SmallVector<Value>>
  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
                            const SliceEncodingAttr &sliceLayout,
                            ArrayRef<int64_t> shape) const {
    auto parent = sliceLayout.getParent();
    unsigned dim = sliceLayout.getDim();
    size_t rank = shape.size();
    // Emit indices for the parent layout on the padded shape, then drop the
    // sliced dimension from every index tuple.
    auto parentIndices =
        emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
    unsigned numIndices = parentIndices.size();
    SmallVector<SmallVector<Value>> resultIndices;
    for (unsigned i = 0; i < numIndices; ++i) {
      SmallVector<Value> indices = parentIndices[i];
      indices.erase(indices.begin() + dim);
      resultIndices.push_back(indices);
    }
    return resultIndices;
  }

protected:
  LLVMTypeConverter *converter;
  const Allocation *allocation;
  Value smem;
  IndexCacheInfo indexCacheInfo;
};

template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
    : public ConvertOpToLLVMPattern<SourceOp>,
      public ConvertTritonGPUOpToLLVMPatternBase {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}

  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                           const Allocation *allocation,
                                           Value smem,
                                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}

  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                           const Allocation *allocation,
                                           Value smem,
                                           IndexCacheInfo indexCacheInfo,
                                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
                                            indexCacheInfo) {}

protected:
  LLVMTypeConverter *getTypeConverter() const {
    return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
  }
};

#endif