From a9464f4993f62d4e5cf68275de664e0b9fb5805d Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 7 Sep 2022 03:28:09 +0800 Subject: [PATCH] [Backend] Vectorize Load/Store Ops (#86) This PR does the following things: - Code refactoring on Load and Store op codegen, rewrite with same logic and share much code - Support the vectorized load/store --- lib/Analysis/Alias.cpp | 4 +- lib/Analysis/AxisInfo.cpp | 2 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 465 +++++++++--------- lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 1 - python/src/triton.cc | 4 + python/tests/test_compiler.py | 7 +- python/tests/test_vecadd_no_scf.py | 59 +-- python/triton/compiler.py | 3 +- python/triton/runtime/jit.py | 129 ++++- test/Conversion/tritongpu_to_llvm.mlir | 54 +- 10 files changed, 433 insertions(+), 295 deletions(-) diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 3cc7968d1..ab3ca544e 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -27,9 +27,9 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( if (isSharedEncoding(result)) { aliasInfo.insert(result); pessimistic = false; + } else { + llvm::errs() << "op: " << op->getName() << "\n"; } - } else { - llvm::errs() << "op: " << op->getName() << "\n"; } // XXX(Keren): triton ops don't support aliasing yet. // else if (auto viewOp = dyn_cast(op) || diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 6bdcc47fa..a6b280cc1 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -214,4 +214,4 @@ ChangeResult AxisInfoAnalysis::visitOperation( return result; } -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index aba9e71d6..77d53d0a1 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,5 +1,6 @@ #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "../PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" @@ -307,16 +308,7 @@ static T getLinearIndex(ArrayRef multidim_index, ArrayRef shape) { return linear_index; } -template -class ConvertTritonGPUOpToLLVMPattern - : public ConvertOpToLLVMPattern { -public: - using OpAdaptor = typename SourceOp::Adaptor; - - explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} - +struct ConvertTritonGPUOpToLLVMPatternBase { SmallVector getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems, ConversionPatternRewriter &rewriter) const { @@ -329,18 +321,18 @@ public: } return results; } +}; - Value getStructFromElements(Location loc, ValueRange resultVals, - ConversionPatternRewriter &rewriter, - Type structType) const { - Value llvmStruct = rewriter.create(loc, structType); - for (auto v : llvm::enumerate(resultVals)) { - llvmStruct = rewriter.create( - loc, structType, llvmStruct, v.value(), - rewriter.getI64ArrayAttr(v.index())); - } - return llvmStruct; - } +template +class ConvertTritonGPUOpToLLVMPattern + : public ConvertOpToLLVMPattern, + public ConvertTritonGPUOpToLLVMPatternBase { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + 
PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} SmallVector delinearize(ConversionPatternRewriter &rewriter, Location loc, Value linear, @@ -432,24 +424,18 @@ public: for (unsigned blockOffset = 0; blockOffset < shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]); - ++blockOffset) { - for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; - ++warpOffset) { + ++blockOffset) + for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset) for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k]; - ++threadOffset) { + ++threadOffset) for (unsigned elemOffset = 0; elemOffset < sizePerThread[k]; - ++elemOffset) { + ++elemOffset) offset[k].push_back(blockOffset * sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k] + warpOffset * sizePerThread[k] * threadsPerWarp[k] + threadOffset * sizePerThread[k] + elemOffset); - } - } - } - } } - // step 3, add offset to base, and reorder the sequence of indices, // to guarantee that elems in a same sizePerThread are adjacent in // order @@ -535,9 +521,9 @@ struct SplatOpConversion } }; -// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the -// logic is the same as triton::SplatOp, so the underlying implementation is -// reused. +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. struct ArithConstantSplatOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -576,20 +562,104 @@ struct ArithConstantSplatOpConversion } }; +// Contains some helper functions for both Load and Store conversions. +struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { + LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) + : AxisAnalysisPass(axisAnalysisPass) {} + + // Get corresponding LLVM element values of \param value. + SmallVector getLLVMElems(Value value, Value llValue, + const BlockedEncodingAttr &layout, + TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) const { + if (!value) + return {}; + + auto ty = value.getType().cast(); + auto shape = ty.getShape(); + // Here, we assume that all inputs should have a blockedLayout + + unsigned valueElems = getElemsPerThread(layout, shape); + + auto llvmElemTy = typeConverter->convertType(ty.getElementType()); + auto llvmElemPtrPtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy)); + + auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter); + return valueVals; + } + + // Get the blocked layout. 
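+  // Returns the blocked encoding attribute of \param val together with the
+  // number of elements each thread handles under that layout.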
+ std::tuple getLayout(Value val) const { + auto ty = val.getType().cast(); + // Here, we assume that all inputs should have a blockedLayout + auto layout = ty.getEncoding().dyn_cast(); + auto shape = ty.getShape(); + unsigned valueElems = getElemsPerThread(layout, shape); + return std::make_tuple(layout, valueElems); + } + + unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const { + auto axisInfo = getAxisInfo(val); + + auto order = layout.getOrder(); + + unsigned maxMultiple = axisInfo->getDivisibility(order[0]); + unsigned maxContig = axisInfo->getContiguity(order[0]); + unsigned alignment = std::min(maxMultiple, maxContig); + return alignment; + } + + unsigned getVectorizeSize(Value ptr, + const BlockedEncodingAttr &layout) const { + auto axisInfo = getAxisInfo(ptr); + auto contig = axisInfo->getContiguity(); + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = layout.getOrder(); + unsigned align = getAlignment(ptr, layout); + + auto getTensorShape = [](Value val) -> ArrayRef { + auto ty = val.getType().cast(); + auto shape = ty.getShape(); + return shape; + }; + + // unsigned contigPerThread = layout.getSizePerThread()[order[0]]; + unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr)); + + unsigned vec = std::min(align, contigPerThread); + + return vec; + } + + llvm::Optional getAxisInfo(Value val) const { + if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) { + return it->getValue(); + } + + return llvm::Optional{}; + } + +protected: + AxisInfoAnalysis &AxisAnalysisPass; +}; + struct StoreOpConversion - : public ConvertTritonGPUOpToLLVMPattern { + : public ConvertTritonGPUOpToLLVMPattern, + public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; StoreOpConversion(LLVMTypeConverter &converter, AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), - AxisAnalysisPass(axisAnalysisPass) {} + LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value ptr = op.ptr(); Value mask = op.mask(); Value value = op.value(); @@ -598,125 +668,52 @@ struct StoreOpConversion Value llMask = adaptor.mask(); Value llValue = adaptor.value(); + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto valueTy = value.getType().dyn_cast(); if (!valueTy) return failure(); Type valueElemTy = getTypeConverter()->convertType(valueTy.getElementType()); - MLIRContext *ctx = rewriter.getContext(); - auto loc = op->getLoc(); + auto [layout, numElems] = getLayout(ptr); - auto getLLVMElems = - [&](Value value, Value llValue, - const BlockedEncodingAttr &layout) -> SmallVector { - auto ty = value.getType().cast(); - auto shape = ty.getShape(); - // Here, we assume that all inputs should have a blockedLayout + auto ptrElems = + getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc); + auto valueElems = + getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc); + assert(ptrElems.size() == valueElems.size()); - unsigned valueElems = getElemsPerThread(layout, shape); - - auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType()); - auto llvmElemPtrPtrTy = - LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy)); - - auto valueVals = - getElementsFromStruct(loc, llValue, valueElems, 
rewriter); - return valueVals; - }; - - auto getLayout = - [&](Value val) -> std::tuple { - auto ty = val.getType().cast(); - auto shape = ty.getShape(); - // Here, we assume that all inputs should have a blockedLayout - auto layout = ty.getEncoding().dyn_cast(); - - unsigned valueElems = getElemsPerThread(layout, shape); - - return std::make_tuple(layout, valueElems); - }; - - auto [ptrLayout, ptrNumElems] = getLayout(ptr); - auto [valueLayout, valueNumElems] = getLayout(value); - - auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout); - auto valueElems = getLLVMElems(value, llValue, valueLayout); SmallVector maskElems; if (llMask) { - auto [maskLayout, maskNumElems] = getLayout(mask); - maskElems = getLLVMElems(mask, llMask, maskLayout); + maskElems = + getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc); assert(valueElems.size() == maskElems.size()); } - auto getAlign = [this](Value val, - const BlockedEncodingAttr &layout) -> unsigned { - auto axisInfo = getAxisInfo(val); - assert(axisInfo.hasValue()); - - auto order = layout.getOrder(); - - unsigned maxMultiple = axisInfo->getDivisibility(order[0]); - unsigned maxContig = axisInfo->getContiguity(order[0]); - unsigned alignment = std::min(maxMultiple, maxContig); - return alignment; - }; - - // get align - auto getVec = [this, - &getAlign](Value val, - const BlockedEncodingAttr &layout) -> unsigned { - auto axisInfo = getAxisInfo(val); - auto contig = axisInfo->getContiguity(); - // Here order should be ordered by contiguous first, so the first element - // should have the largest contiguous. - auto order = layout.getOrder(); - unsigned align = getAlign(val, layout); - - assert(!order.empty()); - // Is this right? - unsigned contigPerThread = layout.getSizePerThread()[order[0]]; - unsigned vec = std::min(align, contigPerThread); - - // TODO(Superjomn) Consider the is_mma_first_row in the legacy code - bool isMMAFirstRow = false; - - if (isMMAFirstRow) - vec = std::min(2, align); - - return vec; - }; - // Determine the vectorization size - size_t vec = getVec(ptr, ptrLayout); + size_t vec = getVectorizeSize(ptr, layout); - const size_t dtsize = value.getType() - .cast() - .getElementType() - .getIntOrFloatBitWidth() / - 8; + const size_t dtsize = + std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNbits = dtsize * 8; - const int numVecs = ptrNumElems / vec; - for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) { - + const int numVecs = numElems / vec; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { // TODO: optimization when ptr is GEP with constant offset size_t in_off = 0; - // pack sub-words (< 32/64bits) into words - // each load has width min(nbits*vec, 32/64) - // and there are (nbits * vec)/width of them const int maxWordWidth = std::max(32, valueElemNbits); const int totalWidth = valueElemNbits * vec; const int width = std::min(totalWidth, maxWordWidth); const int nWords = std::max(1, totalWidth / width); const int wordNElems = width / valueElemNbits; const int vecNElems = totalWidth / valueElemNbits; + assert(wordNElems * nWords * numVecs == numElems); - assert(wordNElems * nWords * numVecs == valueElems.size()); - - // TODO(Superjomn) Add cache policy to store. - // TODO(Superjomn) deal with cache policy. + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. 
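+      // Worked example of the packing above, assuming f32 elements and
+      // vec = 4: valueElemNbits = 32, totalWidth = 128, width = 32,
+      // nWords = 4, wordNElems = 1, so the instruction built below is a
+      // predicated st.global.b32.v4 storing four 32-bit words.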
const bool hasL2EvictPolicy = false; PTXBuilder ptxBuilder; @@ -733,8 +730,9 @@ struct StoreOpConversion Value llWord = rewriter.create(loc, wordTy); // Insert each value element to the composition for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) { - Value elem = - valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx]; + const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; + assert(elemOffset < valueElems.size()); + Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) elem = rewriter.create(loc, type::i8Ty(ctx), elem); elem = rewriter.create(loc, valueElemTy, elem); @@ -751,13 +749,17 @@ struct StoreOpConversion asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); } + // TODO(Superjomn) Need to check masks before vectorize the load for all + // the values share one predicate? Here assume all the mask values are + // the same. Value maskVal = - llMask ? maskElems[vecIdx] + llMask ? maskElems[vecStart] : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), rewriter.getIntegerType(1), 1); ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords); - auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off); + auto *asmAddr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); ptxStoreInstr(asmAddr, asmArgList); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); @@ -782,17 +784,6 @@ struct StoreOpConversion rewriter.eraseOp(op); return success(); } - - llvm::Optional getAxisInfo(Value val) const { - if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) { - return it->getValue(); - } - - return llvm::Optional{}; - } - -private: - AxisInfoAnalysis &AxisAnalysisPass; }; struct BroadcastOpConversion @@ -909,8 +900,10 @@ struct ViewOpConversion struct MakeRangeOpConversion : public ConvertTritonGPUOpToLLVMPattern { - using ConvertTritonGPUOpToLLVMPattern< - triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern; + + MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(converter, + benefit) {} LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, @@ -918,13 +911,12 @@ struct MakeRangeOpConversion Location loc = op->getLoc(); auto rankedTy = op.result().getType().dyn_cast(); auto shape = rankedTy.getShape(); - auto blocked_layout = - rankedTy.getEncoding().dyn_cast(); + auto layout = rankedTy.getEncoding().cast(); + auto elemTy = rankedTy.getElementType(); assert(elemTy.isInteger(32)); Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start()); - auto idxs = - emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape); + auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape); unsigned elems = idxs.size(); SmallVector retVals(elems); for (auto multiDim : llvm::enumerate(idxs)) { @@ -941,92 +933,116 @@ struct MakeRangeOpConversion }; struct LoadOpConversion - : public ConvertTritonGPUOpToLLVMPattern { + : public ConvertTritonGPUOpToLLVMPattern, + public LoadStoreConversionBase { using ConvertTritonGPUOpToLLVMPattern< triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; + LoadOpConversion(LLVMTypeConverter &converter, + AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(axisAnalysisPass) {} + LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - Value ptr = adaptor.ptr(); - 
Value mask = adaptor.mask(); - Value other = adaptor.other(); - auto resultTy = op.result().getType().cast(); - auto blockedLayout = resultTy.getEncoding().dyn_cast(); - auto shape = resultTy.getShape(); - // TODO: Handle AxisInfo - // vecWidth = std::min(nts, aln) - // TODO: special processing for mma_first_row in legacy codes - assert(blockedLayout && "LoadOp only accepts blocked_layout"); - unsigned vecWidth = - blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]]; + Value ptr = op.ptr(); + Value mask = op.mask(); + Value other = op.other(); - auto elemTy = resultTy.getElementType(); - unsigned numElems = getElemsPerThread(blockedLayout, shape); - auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter); - SmallVector maskVals; - if (mask) { - maskVals = getElementsFromStruct(loc, mask, numElems, rewriter); + Value llPtr = adaptor.ptr(); + Value llMask = adaptor.mask(); + Value llOther = adaptor.other(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = op.getResult().getType().dyn_cast(); + if (!valueTy) + return failure(); + Type valueElemTy = + getTypeConverter()->convertType(valueTy.getElementType()); + + auto [layout, numElems] = getLayout(ptr); + + auto ptrElems = + getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc); + assert(ptrElems.size() == numElems); + + SmallVector maskElems; + if (llMask) { + maskElems = + getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc); + assert(ptrElems.size() == maskElems.size()); } - SmallVector otherVals; - if (other) { - otherVals = getElementsFromStruct(loc, other, numElems, rewriter); - } - unsigned nbits = elemTy.isa() - ? elemTy.cast().getWidth() - : elemTy.cast().getWidth(); - // unsigned dtsize = nbits / 8; - int max_word_width = std::max(32, nbits); - int tot_width = nbits * vecWidth; - int width = std::min(tot_width, max_word_width); - int n_words = std::max(1, tot_width / width); - // TODO: currently disable until supported in `store` - bool has_l2_evict_policy = false; + + // Determine the vectorization size + size_t vec = getVectorizeSize(ptr, layout); + + const size_t dtsize = + std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); + const size_t valueElemNbits = dtsize * 8; + + const int numVecs = numElems / vec; // TODO: (goostavz) handle when other is const but not splat, which // should be rarely seen bool otherIsSplatConstInt = false; DenseElementsAttr constAttr; int64_t splatVal = 0; - if (elemTy.isa() && + if (valueElemTy.isa() && matchPattern(op.other(), m_Constant(&constAttr)) && constAttr.isSplat()) { otherIsSplatConstInt = true; splatVal = constAttr.getSplatValue().getSExtValue(); } + auto otherElems = + getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc); + SmallVector loadedVals; - for (size_t i = 0; i < numElems; i += vecWidth) { - Value ptr = ptrVals[i]; - // TODO: Handle the optimization if ptr is from GEP and the idx is - // constant. 
This should be a canonicalization pattern in LLVM Dialect - unsigned in_off = 0; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + // TODO: optimization when ptr is GEP with constant offset + size_t in_off = 0; + + const int maxWordWidth = std::max(32, valueElemNbits); + const int totalWidth = valueElemNbits * vec; + const int width = std::min(totalWidth, maxWordWidth); + const int nWords = std::max(1, totalWidth / width); + const int wordNElems = width / valueElemNbits; + const int vecNElems = totalWidth / valueElemNbits; + assert(wordNElems * nWords * numVecs == numElems); + + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. + const bool hasL2EvictPolicy = false; + + PTXBuilder ptxBuilder; + auto &ld = *ptxBuilder.create("ld"); + + // TODO(Superjomn) Need to check masks before vectorize the load for all + // the values share one predicate? Here assume all the mask values are + // the same. Value pred = - mask ? maskVals[i] + mask ? maskElems[vecStart] : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), rewriter.getIntegerType(1), 1); - // --- - // create inline asm string - // --- - - const std::string readConstrait = + const std::string readConstraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); - const std::string writeConstrait = + const std::string writeConstraint = (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); - PTXBuilder ptxBuilder; - PtxIOInstr &ld = *ptxBuilder.create("ld"); - // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); - for (int i = 0; i < n_words; i++) { - auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations + for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { + auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations dstsOpr->listAppend(opr); } - auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off); + + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); // Define the instruction opcode ld.predicate(pred, "b") @@ -1037,11 +1053,12 @@ struct LoadOpConversion .o("L1::evict_first", op.evict() == triton::EvictionPolicy::EVICT_FIRST) .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", has_l2_evict_policy) - .v(n_words) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) .b(width); PTXBuilder::Operand *evictOpr{}; + // Here lack a mlir::Value to bind to this operation, so disabled. 
// if (has_l2_evict_policy) // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); @@ -1053,16 +1070,16 @@ struct LoadOpConversion SmallVector others; if (other) { - for (size_t ii = 0; ii < n_words; ii++) { + for (size_t ii = 0; ii < nWords; ii++) { PTXInstr &mov = *ptxBuilder.create<>("mov"); mov.predicateNot(pred, "b").o("u", width); - size_t size = width / nbits; + size_t size = width / valueElemNbits; - auto vecTy = LLVM::getFixedVectorType(elemTy, size); + auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); Value v = rewriter.create(loc, vecTy); for (size_t s = 0; s < size; s++) { - Value falseVal = otherVals[i + ii * size + s]; + Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), s); v = rewriter.create(loc, vecTy, v, falseVal, @@ -1075,7 +1092,7 @@ struct LoadOpConversion if (otherIsSplatConstInt) { opr = ptxBuilder.newConstantOperand(splatVal); } else { - opr = ptxBuilder.newOperand(v, readConstrait); + opr = ptxBuilder.newOperand(v, readConstraint); others.push_back(v); } @@ -1086,7 +1103,7 @@ struct LoadOpConversion // --- // create inline ASM signature // --- - SmallVector retTys(n_words, IntegerType::get(getContext(), width)); + SmallVector retTys(nWords, IntegerType::get(getContext(), width)); Type retTy = retTys.size() > 1 ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) : retTys[0]; @@ -1097,7 +1114,8 @@ struct LoadOpConversion auto inlineAsmOp = rewriter.create( loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(), /*asm_string=*/ptxBuilder.dump(), - /*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true, + /*constraints=*/ptxBuilder.getConstrains(), + /*has_side_effects=*/true, /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); Value ret = inlineAsmOp.getResult(0); @@ -1106,8 +1124,8 @@ struct LoadOpConversion // extract and store return values // --- SmallVector rets; - for (unsigned int ii = 0; ii < n_words; ii++) { - Value curr = nullptr; + for (unsigned int ii = 0; ii < nWords; ii++) { + Value curr; if (retTy.isa()) { curr = rewriter.create( loc, IntegerType::get(getContext(), width), ret, @@ -1116,19 +1134,21 @@ struct LoadOpConversion curr = ret; } curr = rewriter.create( - loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr); + loc, LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), + curr); rets.push_back(curr); } - int tmp = (width / nbits); - for (size_t ii = 0; ii < vecWidth; ii++) { + int tmp = (width / valueElemNbits); + for (size_t ii = 0; ii < vec; ii++) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); Value loaded = rewriter.create( - loc, elemTy, rets[ii / tmp], vecIdx); + loc, valueElemTy, rets[ii / tmp], vecIdx); loadedVals.push_back(loaded); } } // end vec - Type llvmResultStructTy = getTypeConverter()->convertType(resultTy); + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); @@ -1272,11 +1292,16 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, benefit); patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, + benefit); + patterns.add>(typeConverter, + benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, numWarps, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); 
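+  // The load/store lowerings additionally take the axis-info analysis so
+  // they can choose a vectorization width.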
- patterns.add(typeConverter, benefit); + patterns.add(typeConverter, analysis, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 63a0a6d1b..704a89734 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -10,7 +10,6 @@ using namespace mlir::triton; #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" struct CoalescePass : public TritonGPUCoalesceBase { - Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr, int numWarps) { auto origType = ptr.getType().cast(); diff --git a/python/src/triton.cc b/python/src/triton.cc index 41f99f84a..9e3485df8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1649,6 +1649,10 @@ void init_triton_ir(py::module &&m) { .def( "add_sccp_pass", [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); }) + .def("add_coalesce_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUCoalescePass()); + }) .def("add_symbol_dce_pass", [](mlir::PassManager &self) { self.addPass(mlir::createSymbolDCEPass()); diff --git a/python/tests/test_compiler.py b/python/tests/test_compiler.py index 452d33f63..ff4ca1bcf 100644 --- a/python/tests/test_compiler.py +++ b/python/tests/test_compiler.py @@ -29,7 +29,6 @@ def test_empty_kernel_cubin_compile(): def test_empty_kernel_launch(): device = torch.cuda.current_device() binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32", - device=device, constants={"BLOCK": 256}, num_warps=4, num_stages=3) @@ -38,11 +37,9 @@ def test_empty_kernel_launch(): ) A = torch.zeros([1024], device="cuda") - runtime.launch_kernel(fn=empty_kernel, - binary=binary, + runtime.launch_kernel(kernel=binary, grid=grid, - num_warps=4, - num_stages=3, + device=device, X=A, stride_xm=256, BLOCK=tl.constexpr(256)) diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py index ee4994d2b..780d76926 100644 --- a/python/tests/test_vecadd_no_scf.py +++ b/python/tests/test_vecadd_no_scf.py @@ -5,17 +5,12 @@ import triton import triton.language as tl import triton.runtime as runtime -NUM_WARPS = 4 -BLOCK_SIZE = 256 -# triton kernel - - -def test_vecadd_no_scf(): +def vecadd_no_scf_tester(num_warps, block_size): @triton.jit - def kernel(x_ptr, stride_xn, - y_ptr, stride_yn, - z_ptr, stride_zn, + def kernel(x_ptr, + y_ptr, + z_ptr, BLOCK_SIZE_N: tl.constexpr): pid = tl.program_id(axis=0) offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -27,37 +22,35 @@ def test_vecadd_no_scf(): z_ptrs = z_ptr + offset tl.store(z_ptrs, z) - # TODO: add this to CI, to make sure the the compilation flow is at lease OK - # before we have GPU machines for CI. 
- # ptx, shem_size, kernel_name = triton.compile(kernel, - # "*fp32,i32,*fp32,i32,*fp32,i32", - # constants={"BLOCK_SIZE_N": 256}, - # num_warps=NUM_WARPS, - # device=0, output="ptx") - torch.zeros([10], device=torch.device('cuda')) device = torch.cuda.current_device() - binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", - device=device, - constants={"BLOCK_SIZE_N": BLOCK_SIZE}, - num_warps=NUM_WARPS, + binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32", + constants={"BLOCK_SIZE_N": block_size}, + num_warps=num_warps, num_stages=3) - grid = lambda META: (1, ) - x = torch.randn((256,), device='cuda', dtype=torch.float32) - y = torch.randn((256,), device='cuda', dtype=torch.float32) - z = torch.empty((256,), device=x.device, dtype=x.dtype) - runtime.launch_kernel(fn=kernel, - binary=binary, + x = torch.randn((block_size,), device='cuda', dtype=torch.float32) + y = torch.randn((block_size,), device='cuda', dtype=torch.float32) + z = torch.empty((block_size,), device=x.device, dtype=x.dtype) + + assert x.shape.numel() % block_size == 0, "Only test load without mask here" + grid = lambda EA: (x.shape.numel() // block_size,) + + runtime.launch_kernel(kernel=binary, grid=grid, - num_warps=NUM_WARPS, - num_stages=3, + device=device, x_ptr=x, - stride_xn=x.stride(0), y_ptr=y, - stride_yn=y.stride(0), z_ptr=z, - stride_zn=z.stride(0), - BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE)) + BLOCK_SIZE_N=tl.constexpr(block_size)) golden_z = x + y assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7) + + +def test_vecadd_no_scf(): + vecadd_no_scf_tester(num_warps=2, block_size=256) + vecadd_no_scf_tester(num_warps=1, block_size=256) + + +if __name__ == '__main__': + test_vecadd_no_scf() diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 203c114a6..f9f9a53c5 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -798,7 +798,8 @@ def optimize_tritongpu_ir(mod, num_stages): pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() - # pm.add_triton_gpu_combine_pass() + pm.add_coalesce_pass() + pm.add_triton_gpu_combine_pass() pm.add_triton_gpu_verifier_pass() pm.run(mod) return mod diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index f99cf9a77..7fe270c97 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -8,7 +8,7 @@ import os import subprocess import tempfile import textwrap -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch @@ -256,41 +256,126 @@ class JITFunction: return f"JITFunction({self.module}:{self.fn.__name__})" +def pow2_divisor(N): + if N % 16 == 0: + return 16 + if N % 8 == 0: + return 8 + if N % 4 == 0: + return 4 + if N % 2 == 0: + return 2 + return 1 + + +class _KernelCache: + def __init__(self, + fn: JITFunction, + fn_type: str, + constants: Dict[str, Any], + num_warps: int = 4, + num_stages: int = 3): + # hold the arguments for building a kernel + self.fn = fn + self.fn_type = fn_type + self.constants = constants + self.num_warps = num_warps + self.num_stages = num_stages + + # kernel compilation cache + self._binary_cache: Optional[LoadedBinary] = None + + @property + def binary_cache(self): + return self._binary_cache + + def set_binary_cache(self, binary: LoadedBinary): + assert binary + assert not self._binary_cache, "cannot set binary cache duplicately" + self._binary_cache = binary + + def build_kernel(fn: JITFunction, fn_type: str, - device: int, constants: Dict[str, Any], num_warps: int = 4, 
num_stages: int = 3, - ) -> LoadedBinary: - cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin") - assert cubin - assert kernel_name - - backend = _triton.runtime.backend.CUDA - - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size) - - asm = dict(cubin=cubin) - binary = Binary(backend, kernel_name, asm, shem_size, num_warps) - loaded_binary = LoadedBinary(device, binary) - return loaded_binary + ) -> _KernelCache: + return _KernelCache(fn, fn_type, constants, num_warps, num_stages) -def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs): - kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()} +torch_dtype_to_bytes = { + torch.int8: 1, + torch.uint8: 1, + + torch.int16: 2, + torch.short: 2, + + torch.int: 4, + torch.int32: 4, + + torch.long: 8, + torch.int64: 8, + + torch.float32: 4, + torch.float: 4, + + torch.float16: 2, + torch.half: 2, + torch.bfloat16: 2, + # free to extend +} + + +def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs): + def is_tensor(arg): + return hasattr(arg, 'data_ptr') # a torch.tensor + + # prepare function args for compile + kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()} wargs = list(wargs) for i, pos in enumerate(sorted(kwargs)): wargs.insert(pos + i, kwargs[pos]) - assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs)) + assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs)) + + if not kernel.binary_cache: + # build the kernel cache + backend = _triton.runtime.backend.CUDA + + attributes = dict() + for i, arg in enumerate(wargs): + if i in kernel.fn.do_not_specialize: + continue + if isinstance(arg, int): + attributes[i] = pow2_divisor(arg) + elif is_tensor(arg): + assert arg.dtype in torch_dtype_to_bytes + addr = arg.data_ptr() + range_size = _triton.runtime.get_pointer_range_size(addr) + divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype] + attributes[i] = divisibility + + attributes_ = dict() + for i, value in attributes.items(): + attributes_[kernel.fn.arg_names[i]] = value + + cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin") + assert cubin + assert kernel_name + + max_shared_memory = _triton.runtime.max_shared_memory(backend, device) + assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size) + + asm = dict(cubin=cubin) + binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps) + loaded_binary = LoadedBinary(device, binary) + kernel.set_binary_cache(loaded_binary) - device = torch.cuda.current_device() torch.cuda.set_device(device) stream = get_cuda_stream(device) - _triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names, - stream, num_warps, num_stages, grid) + _triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names, + stream, kernel.num_warps, 
kernel.num_stages, grid) # ----------------------------------------------------------------------------- diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index b164db9b0..c811bf516 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -28,14 +28,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: vectorized_load func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm - // CHECK-SAME: ld.global.v4.b32 + // CHECK-SAME: ld.global.b32 // CHECK: llvm.inline_asm - // CHECK-SAME: ld.global.v4.b32 + // CHECK-SAME: ld.global.b32 %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return } @@ -43,14 +43,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: vectorized_load_f16 - func @vectorized_load_f16(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { + func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { // CHECK: llvm.inline_asm - // CHECK-SAME: ld.global.v2.b32 + // CHECK-SAME: ld.global.b16 // CHECK: llvm.inline_asm - // CHECK-SAME: ld.global.v2.b32 + // CHECK-SAME: ld.global.b16 %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0> return } @@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- // TODO: Pending on the support of isSplat constant -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { @@ -69,6 +69,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { } } +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +module attributes {"triton_gpu.num-warps" = 2 : i32} { + // CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256 + func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : (i32) -> tensor<256xi32, 
#blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + + // CHECK: ld.global.v4.b32 + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + // CHECK: ld.global.v4.b32 + %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + + // Store 4 elements to global + // CHECK: st.global.b32.v4 + tt.store %13, %11 : tensor<256xf32, #blocked0> + return + } +} + + + // TODO: Add a testcase to verify the optimization when ptr of the LoadOp // is from a GEP with const idx @@ -99,7 +133,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_make_range func @basic_make_range() {