diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index bfb658a96..f39e32aec 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -107,7 +107,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
   }
 
   llvm::LLVMContext llvmContext;
-  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
+  auto llvmir =
+      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
   if (!llvmir) {
     llvm::errs() << "Translate to LLVM IR failed";
   }
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 0053776c4..22f6692da 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -12,6 +12,8 @@ bool isSharedEncoding(Value value);
 
 bool maybeSharedAllocationOp(Operation *op);
 
+bool maybeAliasOp(Operation *op);
+
 std::string getValueOperandName(Value value, AsmState &state);
 
 template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td
index e78b4dc4f..70bb20b78 100644
--- a/include/triton/Conversion/Passes.td
+++ b/include/triton/Conversion/Passes.td
@@ -43,6 +43,12 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
                              "mlir::triton::gpu::TritonGPUDialect",
                              "mlir::NVVM::NVVMDialect",
                              "mlir::StandardOpsDialect"];
+
+    let options = [
+        Option<"computeCapability", "compute-capability",
+               "int32_t", /*default*/"80",
+               "device compute capability">
+    ];
 }
 
 #endif
diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
index 7c4143c11..86ddc066d 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
@@ -33,7 +33,8 @@ struct NVVMMetadataField {
   static constexpr char Kernel[] = "nvvm.kernel";
 };
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertTritonGPUToLLVMPass(int computeCapability = 80);
 
 } // namespace triton
diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
index 1b8b399d7..a540a9478 100644
--- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h
+++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
@@ -25,7 +25,8 @@ void addExternalLibs(mlir::ModuleOp &module,
 // Translate TritonGPU dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
-                           mlir::ModuleOp module);
+                           mlir::ModuleOp module,
+                           int computeCapability);
 
 // Translate mlir LLVM dialect to LLVMIR, return null if failed.
 std::unique_ptr<llvm::Module>
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index b938b29a6..25ba3aeb0 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -26,13 +26,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
   // These ops may allocate a new shared memory buffer.
   auto result = op->getResult(0);
   // FIXME(Keren): extract and insert are always alias for now
-  if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+  if (isa<tensor::ExtractSliceOp>(op)) {
     // extract_slice %src
     aliasInfo = AliasInfo(operands[0]->getValue());
     pessimistic = false;
-  } else if (auto insertSliceOp =
-                 dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
+  } else if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
+             isa<tensor::InsertSliceOp>(op)) {
     // insert_slice_async %src, %dst, %index
+    // insert_slice %src into %dst[%offsets]
     aliasInfo = AliasInfo(operands[1]->getValue());
     pessimistic = false;
   } else if (isSharedEncoding(result)) {
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 82e91224f..a155df8de 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -28,7 +28,7 @@ namespace mlir {
 namespace triton {
 
 // Bitwidth of pointers
-constexpr int kPtrBitWidth = 64;
+constexpr int kPtrBitWidth = 64;
 
 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
@@ -155,8 +155,7 @@ private:
     // For example: %a = scf.if -> yield
     // %a must be allocated elsewhere by other operations.
     // FIXME(Keren): extract and insert are always alias for now
-    if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
-        isa<triton::gpu::InsertSliceAsyncOp>(op)) {
+    if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
      return;
     }
 
@@ -210,9 +209,9 @@ private:
       auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
-      auto bytes = srcTy.getElementType().isa<triton::PointerType>()?
-                   elems * kPtrBitWidth / 8 :
-                   elems * srcTy.getElementTypeBitWidth() / 8;
+      auto bytes = srcTy.getElementType().isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * srcTy.getElementTypeBitWidth() / 8;
       allocation->addBuffer(op, bytes);
     } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
       auto value = op->getOperand(0);
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index 2822b7ace..715265c0a 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -1,4 +1,5 @@
 #include "triton/Analysis/Membar.h"
+#include "triton/Analysis/Alias.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 
 #include "mlir/Dialect/GPU/GPUDialect.h"
@@ -71,11 +72,17 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
 
   RegionInfo curRegionInfo;
   for (Value value : op->getOperands()) {
-    // ConvertLayoutOp: shared memory -> registers
-    // Need to consider all alias buffers
     for (auto bufferId : allocation->getBufferIds(value)) {
       if (bufferId != Allocation::InvalidBufferId) {
-        curRegionInfo.syncReadBuffers.insert(bufferId);
+        if (isa<tensor::InsertSliceOp>(op) ||
+            isa<triton::gpu::InsertSliceAsyncOp>(op)) {
+          // FIXME(Keren): insert_slice and insert_slice_async are always alias
+          // for now
+          curRegionInfo.syncWriteBuffers.insert(bufferId);
+        } else {
+          // ConvertLayoutOp: shared memory -> registers
+          curRegionInfo.syncReadBuffers.insert(bufferId);
+        }
       }
     }
   }
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 5fcb9654d..780b9bf9a 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -28,6 +28,12 @@ bool maybeSharedAllocationOp(Operation *op) {
          dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
 }
 
+bool maybeAliasOp(Operation *op) {
+  return isa<tensor::ExtractSliceOp>(op) ||
+         isa<triton::gpu::InsertSliceAsyncOp>(op) ||
+         isa<tensor::InsertSliceOp>(op);
+}
+
 std::string getValueOperandName(Value value, AsmState &state) {
   std::string opName;
   llvm::raw_string_ostream ss(opName);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 0c4557698..471ab5b72 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -205,6 +205,20 @@ auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
       b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
 }
 
+/// Helper function to get strides from a given shape and its order
+auto getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
+                                 ArrayRef<unsigned> order, Location loc,
+                                 ConversionPatternRewriter &rewriter) {
+  auto rank = shape.size();
+  SmallVector<Value> strides(rank);
+  auto stride = 1;
+  for (auto idx : order) {
+    strides[idx] = i32_val(stride);
+    stride *= shape[idx];
+  }
+  return strides;
+}
+
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
@@ -452,13 +466,10 @@ struct SharedMemoryObject {
                      ArrayRef<unsigned> order, Location loc,
                      ConversionPatternRewriter &rewriter)
       : base(base) {
-    auto rank = shape.size();
-    auto stride = 1;
-    strides.resize(rank);
+    strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
+
     for (auto idx : order) {
-      strides[idx] = i32_val(stride);
       offsets.emplace_back(i32_val(0));
-      stride *= shape[idx];
     }
   }
 
@@ -2835,6 +2846,112 @@ public:
     return failure();
   }
 
+  static void storeBlockedToShared(Value src, Value llSrc,
+                                   ArrayRef<Value> srcStrides,
+                                   ArrayRef<Value> srcIndices, Value dst,
+                                   Value smemBase, Type elemPtrTy, Location loc,
+                                   ConversionPatternRewriter &rewriter) {
+    auto srcTy = src.getType().cast<RankedTensorType>();
+    auto srcShape = srcTy.getShape();
+    assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
+
+    auto elemTy = srcTy.getElementType();
+    auto dstTy = dst.getType().cast<RankedTensorType>();
+    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
+    auto inOrd = srcBlockedLayout.getOrder();
+    auto outOrd = dstSharedLayout.getOrder();
+    unsigned inVec =
+        inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+    unsigned outVec = dstSharedLayout.getVec();
+    unsigned minVec = std::min(outVec, inVec);
+    unsigned perPhase = dstSharedLayout.getPerPhase();
+    unsigned maxPhase = dstSharedLayout.getMaxPhase();
+    unsigned numElems = getElemsPerThread(srcTy);
+    auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
+    auto srcAccumSizeInThreads =
+        product<unsigned>(srcBlockedLayout.getSizePerThread());
+    auto wordTy = vec_ty(elemTy, minVec);
+
+    // TODO: [goostavz] We should make a cache for the calculation of
+    // emitBaseIndexForBlockedLayout in case backend compiler not being able to
+    // optimize that
+    SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
+    SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
+                               ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
+
+    // Visit each input value in the order they are placed in inVals
+    //
+    // Please note that the order was not awaring of blockLayout.getOrder(),
+    // thus the adjacent elems may not belong to a same word. This could be
+    // improved if we update the elements order by emitIndicesForBlockedLayout()
+    SmallVector<unsigned> wordsInEachRep(2);
+    wordsInEachRep[0] = inOrd[0] == 0
+                            ? srcBlockedLayout.getSizePerThread()[0] / minVec
+                            : srcBlockedLayout.getSizePerThread()[0];
+    wordsInEachRep[1] = inOrd[0] == 0
+                            ? srcBlockedLayout.getSizePerThread()[1]
+                            : srcBlockedLayout.getSizePerThread()[1] / minVec;
+    Value outVecVal = i32_val(outVec);
+    Value minVecVal = i32_val(minVec);
+    auto numWordsEachRep = product<unsigned>(wordsInEachRep);
+    SmallVector<Value> wordVecs(numWordsEachRep);
+    for (unsigned i = 0; i < numElems; ++i) {
+      if (i % srcAccumSizeInThreads == 0) {
+        // start of a replication
+        for (unsigned w = 0; w < numWordsEachRep; ++w) {
+          wordVecs[w] = undef(wordTy);
+        }
+      }
+      unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
+      auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
+          linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
+      unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
+      multiDimIdxInNanoTile[inOrd[0]] /= minVec;
+      auto wordVecIdx = getLinearIndex<unsigned>(multiDimIdxInNanoTile,
+                                                 wordsInEachRep, inOrd);
+      wordVecs[wordVecIdx] =
+          insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
+
+      if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
+        // end of replication, store the vectors into shared memory
+        unsigned linearRepIdx = i / srcAccumSizeInThreads;
+        auto multiDimRepIdx =
+            getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
+        for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
+             ++linearWordIdx) {
+          // step 1: recover the multidim_index from the index of input_elements
+          auto multiDimWordIdx =
+              getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
+          SmallVector<Value> multiDimIdx(2);
+          auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
+                             multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
+          auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
+                             multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
+          multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
+          multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
+
+          // step 2: do swizzling
+          Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+          multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+          Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
+          Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+          phaseId = urem(phaseId, i32_val(maxPhase));
+          Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+          off_0 = mul(off_0, outVecVal);
+          remained = udiv(remained, minVecVal);
+          off_0 = add(off_0, mul(remained, minVecVal));
+          Value offset = add(off_1, off_0);
+
+          // step 3: store
+          Value smemAddr = gep(elemPtrTy, smemBase, offset);
+          smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+          store(wordVecs[linearWordIdx], smemAddr);
+        }
+      }
+    }
+  }
+
 private:
   SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                        ConversionPatternRewriter &rewriter,
@@ -3129,110 +3246,91 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
   auto inOrd = srcBlockedLayout.getOrder();
   auto outOrd = dstSharedLayout.getOrder();
-  unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
-  unsigned outVec = dstSharedLayout.getVec();
-  unsigned minVec = std::min(outVec, inVec);
-  unsigned perPhase = dstSharedLayout.getPerPhase();
-  unsigned maxPhase = dstSharedLayout.getMaxPhase();
-  unsigned numElems = getElemsPerThread(srcTy);
-  auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
-  auto elemTy = srcTy.getElementType();
-  auto wordTy = vec_ty(elemTy, minVec);
-
-  // TODO: [goostavz] We should make a cache for the calculation of
-  // emitBaseIndexForBlockedLayout in case backend compiler not being able to
-  // optimize that
-  SmallVector<Value> multiDimOffsetFirstElem =
-      emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
-  Value outVecVal = idx_val(outVec);
-  Value minVecVal = idx_val(minVec);
   Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
+  auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
   auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
   smemBase = bitcast(smemBase, elemPtrTy);
+
+  auto srcStrides = getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
+  auto srcIndices =
+      emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
+  storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
+                       smemBase, elemPtrTy, loc, rewriter);
+
   auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
   auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
-  for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
-
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(multiDimOffsetFirstElem[0], idx_val(wordOffset0));
-        multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
-
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
-        phaseId = urem(phaseId, idx_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
-    }
-  }
-  // Barrier is not necessary.
-  // The membar pass knows that it writes to shared memory and will handle it
-  // properly.
   rewriter.replaceOp(op, retVal);
   return success();
 }
 
+struct InsertSliceOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // %dst = insert_slice %src into %dst[%offsets]
+    Location loc = op->getLoc();
+    Value dst = op.dest();
+    Value src = op.source();
+    Value res = op.result();
+    assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
+           "Only support in-place insert_slice for now");
+
+    auto srcTy = src.getType().dyn_cast<RankedTensorType>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto srcShape = srcTy.getShape();
+    assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");
+
+    auto dstTy = dst.getType().dyn_cast<RankedTensorType>();
+    auto dstLayout = dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
+    auto llDst = adaptor.dest();
+    assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion");
+    assert(op.hasUnitStride() &&
+           "Only unit stride supported by InsertSliceOpConversion");
+
+    // newBase = base + offset
+    // Triton support either static and dynamic offsets
+    auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
+    SmallVector<Value> offsets;
+    SmallVector<Value> srcStrides;
+    auto mixedOffsets = op.getMixedOffsets();
+    for (auto i = 0; i < mixedOffsets.size(); ++i) {
+      if (op.isDynamicOffset(i)) {
+        offsets.emplace_back(adaptor.offsets()[i]);
+      } else {
+        offsets.emplace_back(i32_val(op.getStaticOffset(i)));
+      }
+      // Like insert_slice_async, we only support slice from one dimension,
+      // which has a slice size of 1
+      if (op.getStaticSize(i) != 1) {
+        srcStrides.emplace_back(smemObj.strides[i]);
+      }
+    }
+
+    // Compute the offset based on the original strides of the shared memory
+    // object
+    auto offset = dot(rewriter, loc, offsets, smemObj.strides);
+    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+    auto smemBase = gep(elemPtrTy, smemObj.base, offset);
+
+    auto llSrc = adaptor.source();
+    auto srcIndices =
+        emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
+    ConvertLayoutOpConversion::storeBlockedToShared(src, llSrc, srcStrides,
+                                                    srcIndices, dst, smemBase,
+                                                    elemPtrTy, loc, rewriter);
+    // Barrier is not necessary.
+    // The membar pass knows that it writes to shared memory and will handle it
+    // properly.
+    rewriter.replaceOp(op, llDst);
+    return success();
+  }
+};
+
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
@@ -5972,7 +6070,7 @@ struct AtomicRMWOpConversion
     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
     auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
-
+
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
     Type valueElemTy =
         valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
@@ -6166,11 +6264,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add(typeConverter, allocation, smem, benefit);
   patterns.add(typeConverter, allocation, smem, benefit);
-  patterns.add(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
+  patterns.add(typeConverter, allocation, smem,
+               axisInfoAnalysis, benefit);
   patterns.add(typeConverter, allocation, smem, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
+  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
+                                        benefit);
   patterns.add(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
@@ -6216,8 +6317,57 @@ private:
     });
   }
 
+  void decomposeInsertSliceAsyncOp(ModuleOp mod,
+                                   TritonGPUToLLVMTypeConverter &converter) {
+    // cp.async is supported in Ampere and later
+    if (computeCapability >= 80)
+      return;
+
+    // insert_slice_async %src, %dst, %idx, %mask, %other
+    // =>
+    // %tmp = load %src, %mask, %other
+    // %res = insert_slice %tmp into %dst[%idx]
+    mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
+      OpBuilder builder(insertSliceAsyncOp);
+      // load
+      auto srcTy = insertSliceAsyncOp.src().getType().cast<RankedTensorType>();
+      auto dstTy = insertSliceAsyncOp.getType().cast<RankedTensorType>();
+      auto srcBlocked =
+          srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+      auto elemTy = converter.convertType(dstTy.getElementType());
+      auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked);
+      auto loadOp = builder.create<triton::LoadOp>(
+          insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
+          insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
+          insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
+          insertSliceAsyncOp.isVolatile());
+      // insert_slice
+      auto axis = insertSliceAsyncOp.axis();
+      auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
+      auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
+      auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
+      auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
+      offsets[axis] = insertSliceAsyncOp.index();
+      for (size_t i = 0; i < dstTy.getRank(); i++) {
+        if (i != axis)
+          sizes[i] = intAttr(dstTy.getShape()[i]);
+      }
+      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+          insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.dst(),
+          offsets, sizes, strides);
+      // Replace
+      insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
+      insertSliceAsyncOp.erase();
+    });
+
+    mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
+      asyncWaitOp.erase();
+    });
+  }
+
 public:
-  ConvertTritonGPUToLLVM() = default;
+  explicit ConvertTritonGPUToLLVM(int computeCapability)
+      : computeCapability(computeCapability) {}
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
@@ -6233,18 +6383,22 @@ public:
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
     // step 1: Decompose unoptimized layout conversions to use shared memory
-    // step 2: Allocate shared memories and insert barriers
-    // step 3: Convert SCF to CFG
-    // step 4: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 5: Convert the rest of ops via partial conversion
-    // The reason for putting step 1 before step 2 is that the membar analysis
-    // currently only supports SCF but not CFG.
-    // The reason for a separation between 1/4 is that, step 3 is out of
-    // the scope of Dialect Conversion, thus we need to make sure the smem
-    // is not revised during the conversion of step 4.
+    // step 2: Decompose insert_slice_async to use load + insert_slice for
+    //         pre-Ampere architectures
+    // step 3: Allocate shared memories and insert barriers
+    // step 4: Convert SCF to CFG
+    // step 5: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 6: Convert the rest of ops via partial
+    // conversion The reason for putting step 1 before step 2 is that the membar
+    // analysis currently only supports SCF but not CFG. The reason for a
+    // separation between 1/4 is that, step 3 is out of the scope of Dialect
+    // Conversion, thus we need to make sure the smem is not revised during the
+    // conversion of step 4.
     decomposeBlockedToDotOperand(mod);
 
+    decomposeInsertSliceAsyncOp(mod, typeConverter);
+
     Allocation allocation(mod);
     MembarAnalysis membar(&allocation);
@@ -6303,6 +6457,8 @@ protected:
                          TritonGPUToLLVMTypeConverter &typeConverter);
 
   Value smem;
+
+  int computeCapability{};
 };
 
 void ConvertTritonGPUToLLVM::initSharedMemory(
@@ -6365,8 +6521,9 @@ TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
 
 namespace triton {
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
-  return std::make_unique<::ConvertTritonGPUToLLVM>();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertTritonGPUToLLVMPass(int computeCapability) {
+  return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
 }
 
 } // namespace triton
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index cf62ff578..4ef695276 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -202,8 +202,7 @@ LogicalResult LoopPipeliner::initialize() {
       bufferShape.insert(bufferShape.begin(), numStages);
       auto sharedEnc = ttg::SharedEncodingAttr::get(
           ty.getContext(), dotOpEnc, ty.getShape(),
-          triton::gpu::getOrder(ty.getEncoding()),
-          ty.getElementType());
+          triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
       loadsBufferType[loadOp] = RankedTensorType::get(
           bufferShape, ty.getElementType(), sharedEnc);
     }
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index d2994f923..4ab7876d6 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -119,7 +119,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
 
 std::unique_ptr<llvm::Module>
 translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
-                           mlir::ModuleOp module) {
+                           mlir::ModuleOp module, int computeCapability) {
   mlir::PassManager pm(module->getContext());
   applyPassManagerCLOptions(pm);
   auto printingFlags = mlir::OpPrintingFlags();
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 6fd5003e6..cfe092821 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1107,7 +1107,8 @@ void init_triton_ir(py::module &&m) {
              mlir::Value &mask) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             mlir::Type dstType;
-            if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+            if (auto srcTensorType =
+                    ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
              mlir::Type dstElemType = srcTensorType.getElementType()
                                            .cast<mlir::triton::PointerType>()
                                            .getPointeeType();
@@ -1315,8 +1316,8 @@ void init_triton_translation(py::module &m) {
       "translate_triton_gpu_to_llvmir",
       [](mlir::ModuleOp op, int computeCapability) {
         llvm::LLVMContext llvmContext;
-        auto llvmModule =
-            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+        auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
+            &llvmContext, op, computeCapability);
 
         if (!llvmModule)
           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir
index ea3c1e7e6..e81480297 100644
--- a/test/Analysis/test-alias.mlir
+++ b/test/Analysis/test-alias.mlir
@@ -65,6 +65,20 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   return
 }
 
+// CHECK-LABEL: insert_slice
+func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
+  %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
+  %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
+  // CHECK: %cst_0 -> %cst_0
+  %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
+  %index = arith.constant 0 : index
+  %a = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
+  // CHECK: %3 -> %cst_0
+  %b = tensor.insert_slice %a into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
+  return
+}
+
 // CHECK-LABEL: extract_slice
 func @extract_slice(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index 50c8c22c1..c2afeb386 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -119,8 +119,26 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED>
   %index = arith.constant 0 : i32
   %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: Membar 6
   %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
-  // CHECK: Membar 7
+  // CHECK: Membar 8
   %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
   return
 }
+
+// CHECK-LABEL: insert_slice
+func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
+  %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
+  %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
+  %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
+  %index = arith.constant 0 : index
+  %al = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
+  // CHECK: Membar 6
+  %a = tensor.insert_slice %al into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: Membar 8
+  %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
+  // CHECK: Membar 10
+  %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
+  return
+}
diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp
index 58c43ade6..1e4be3ea1 100644
--- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp
+++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp
@@ -34,7 +34,8 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
   // create element type
   Type eltType = IntegerType::get(&ctx, params.typeWidth);
 
-  auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
+  auto layout =
+      SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType);
 
   ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
   ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);