From 7d90a07d0bdafb219bcdf3675f3698694250db6e Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 30 Nov 2022 10:07:34 -0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929) 1. Improve pipline's comment 2. Decompose insert_slice_async when load vector size is not supported 3. Add a test that could fail our gemm code Copy my comments here: There's a knob that may cause performance regression when decomposition has been performed. We should remove this knob once we have thorough analysis on async wait. Currently, we decompose `insert_slice_async` into `load` and `insert_slice` without knowing which `async_wait` is responsible for the `insert_slice_async`. To guarantee correctness, we blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed. There are two options to improve this: 1. We can perform a dataflow analysis to find the `async_wait` that is responsible for the `insert_slice_async` in the backend. 4. We can modify the pipeline to perform the decomposition before the `async_wait` is inserted. However, it is also risky because we don't know the correct vectorized shape yet in the pipeline pass. Making the pipeline pass aware of the vectorization could introduce additional dependencies on the AxisInfoAnalysis and the Coalesce analysis. --- include/triton/Analysis/AxisInfo.h | 6 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 6 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 4 + lib/Analysis/AxisInfo.cpp | 42 ++++++ .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 121 ++++++++++-------- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 4 +- lib/Dialect/Triton/IR/Ops.cpp | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 15 ++- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 76 ++++++----- python/src/triton.cc | 5 +- python/tests/test_gemm.py | 5 +- test/Conversion/tritongpu_to_llvm.mlir | 39 ++++++ 12 files changed, 219 insertions(+), 107 deletions(-) diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 6026e2648..f9cb2e66b 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -131,6 +131,12 @@ public: ChangeResult visitOperation(Operation *op, ArrayRef *> operands) override; + + unsigned getPtrVectorSize(Value ptr); + + unsigned getPtrAlignment(Value ptr); + + unsigned getMaskAlignment(Value mask); }; } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index c4f169935..8c24a5777 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -25,11 +25,11 @@ namespace gpu { unsigned getElemsPerThread(Type type); -SmallVector getThreadsPerWarp(Attribute layout); +SmallVector getThreadsPerWarp(const Attribute &layout); -SmallVector getWarpsPerCTA(Attribute layout); +SmallVector getWarpsPerCTA(const Attribute &layout); -SmallVector getSizePerThread(Attribute layout); +SmallVector getSizePerThread(const Attribute &layout); SmallVector getContigPerThread(Attribute layout); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index e5b1da097..1fe76624d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -151,6 +151,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", // attr-dict `:` type($src) `->` type($dst) //}]; + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability); + }]; + // The custom parser could be replaced with oilist in LLVM-16 let parser = [{ return parseInsertSliceAsyncOp(parser, result); }]; diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 58e3efa3c..42394c3a3 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -276,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation( return result; } +unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) { + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + auto shape = tensorTy.getShape(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(layout); + unsigned align = getPtrAlignment(ptr); + + unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]]; + unsigned vec = std::min(align, contigPerThread); + vec = std::min(shape[order[0]], vec); + + return vec; +} + +unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + return 1; + auto axisInfo = lookupLatticeElement(ptr)->getValue(); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + unsigned maxMultiple = axisInfo.getDivisibility(order[0]); + unsigned maxContig = axisInfo.getContiguity(order[0]); + unsigned alignment = std::min(maxMultiple, maxContig); + return alignment; +} + +unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = mask.getType().dyn_cast(); + if (!tensorTy) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto maskAxis = lookupLatticeElement(mask)->getValue(); + auto alignment = std::max(maskAxis.getConstancy(maskOrder[0]), 1); + return alignment; +} + } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index acf33a905..349ccef01 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -822,7 +822,7 @@ struct ArithConstantSplatOpConversion // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) - : AxisAnalysisPass(axisAnalysisPass) {} + : axisAnalysisPass(axisAnalysisPass) {} // Get corresponding LLVM element values of \param value. static SmallVector getLLVMElems(Value value, Value llValue, @@ -838,51 +838,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { } unsigned getVectorSize(Value ptr) const { - auto tensorTy = ptr.getType().dyn_cast(); - if (!tensorTy) - return 1; - auto layout = tensorTy.getEncoding(); - auto shape = tensorTy.getShape(); - - auto axisInfo = getAxisInfo(ptr); - // Here order should be ordered by contiguous first, so the first element - // should have the largest contiguous. - auto order = getOrder(layout); - unsigned align = getAlignment(ptr, layout); - - unsigned contigPerThread = getSizePerThread(layout)[order[0]]; - unsigned vec = std::min(align, contigPerThread); - vec = std::min(shape[order[0]], vec); - - return vec; - } - - unsigned getAlignment(Value val, const Attribute &layout) const { - auto axisInfo = getAxisInfo(val); - auto order = getOrder(layout); - unsigned maxMultiple = axisInfo->getDivisibility(order[0]); - unsigned maxContig = axisInfo->getContiguity(order[0]); - unsigned alignment = std::min(maxMultiple, maxContig); - return alignment; + return axisAnalysisPass.getPtrVectorSize(ptr); } unsigned getMaskAlignment(Value mask) const { - auto tensorTy = mask.getType().cast(); - auto maskOrder = getOrder(tensorTy.getEncoding()); - auto maskAxis = getAxisInfo(mask); - return std::max(maskAxis->getConstancy(maskOrder[0]), 1); - } - - llvm::Optional getAxisInfo(Value val) const { - if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) { - return it->getValue(); - } - - return llvm::Optional{}; + return axisAnalysisPass.getMaskAlignment(mask); } protected: - AxisInfoAnalysis &AxisAnalysisPass; + AxisInfoAnalysis &axisAnalysisPass; }; struct LoadOpConversion @@ -4604,30 +4568,68 @@ private: }); } - void decomposeInsertSliceAsyncOp(ModuleOp mod, - TritonGPUToLLVMTypeConverter &converter) { - // cp.async is supported in Ampere and later - if (computeCapability >= 80) - return; - + void decomposeInsertSliceAsyncOp(ModuleOp mod) { + AxisInfoAnalysis axisInfoAnalysis(mod.getContext()); + axisInfoAnalysis.run(mod); + // TODO(Keren): This is a hacky knob that may cause performance regression + // when decomposition has been performed. We should remove this knob once we + // have thorough analysis on async wait. Currently, we decompose + // `insert_slice_async` into `load` and `insert_slice` without knowing which + // `async_wait` is responsible for the `insert_slice_async`. To guarantee + // correctness, we blindly set the `async_wait` to wait for all async ops. + // + // There are two options to improve this: + // 1. We can perform a dataflow analysis to find the `async_wait` that is + // responsible for the `insert_slice_async` in the backend. + // 2. We can modify the pipeline to perform the decomposition before the + // `async_wait` is inserted. However, it is also risky because we don't know + // the correct vectorized shape yet in the pipeline pass. Making the + // pipeline pass aware of the vectorization could introduce additional + // dependencies on the AxisInfoAnalysis and the Coalesce analysis. + bool decomposed = false; // insert_slice_async %src, %dst, %idx, %mask, %other // => // %tmp = load %src, %mask, %other // %res = insert_slice %tmp into %dst[%idx] mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void { OpBuilder builder(insertSliceAsyncOp); - // load - auto srcTy = insertSliceAsyncOp.src().getType().cast(); - auto dstTy = insertSliceAsyncOp.getType().cast(); + + // Get the vectorized load size + auto src = insertSliceAsyncOp.src(); + auto dst = insertSliceAsyncOp.dst(); + auto srcTy = src.getType().cast(); + auto dstTy = dst.getType().cast(); auto srcBlocked = srcTy.getEncoding().dyn_cast(); - auto elemTy = converter.convertType(dstTy.getElementType()); - auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked); + auto resSharedLayout = + dstTy.getEncoding().dyn_cast(); + auto resElemTy = dstTy.getElementType(); + unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src); + unsigned outVec = resSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + auto maxBitWidth = + std::max(128, resElemTy.getIntOrFloatBitWidth()); + auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec; + auto bitWidth = std::min(maxBitWidth, vecBitWidth); + auto byteWidth = bitWidth / 8; + + // If the load byte width is not eligible or the current compute + // capability does not support async copy, then we do decompose + if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth( + computeCapability) + .contains(byteWidth) && + computeCapability >= 80) + return; + + // load + auto tmpTy = + RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked); auto loadOp = builder.create( insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(), insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(), insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(), insertSliceAsyncOp.isVolatile()); + // insert_slice auto axis = insertSliceAsyncOp.axis(); auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; @@ -4645,10 +4647,20 @@ private: // Replace insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult()); insertSliceAsyncOp.erase(); + decomposed = true; }); + // async wait is supported in Ampere and later mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void { - asyncWaitOp.erase(); + if (computeCapability < 80) { + asyncWaitOp.erase(); + } else if (decomposed) { + OpBuilder builder(asyncWaitOp); + // Wait for all previous async ops + auto newAsyncWaitOp = builder.create( + asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0)); + asyncWaitOp.erase(); + } }); } @@ -4671,7 +4683,7 @@ public: // step 1: Decompose unoptimized layout conversions to use shared memory // step 2: Decompose insert_slice_async to use load + insert_slice for - // pre-Ampere architectures + // pre-Ampere architectures or unsupported vectorized load sizes // step 3: Allocate shared memories and insert barriers // step 4: Convert SCF to CFG // step 5: Convert FuncOp to LLVMFuncOp via partial conversion @@ -4681,10 +4693,9 @@ public: // separation between 1/4 is that, step 3 is out of the scope of Dialect // Conversion, thus we need to make sure the smem is not revised during the // conversion of step 4. - decomposeBlockedToDotOperand(mod); - decomposeInsertSliceAsyncOp(mod, typeConverter); + decomposeInsertSliceAsyncOp(mod); Allocation allocation(mod); MembarAnalysis membar(&allocation); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 637acab64..b9fd481a6 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -286,8 +286,8 @@ struct TritonAtomicCASPattern matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.ptr(), adaptor.cmp(), adaptor.val()); + op, typeConverter->convertType(op.getType()), adaptor.ptr(), + adaptor.cmp(), adaptor.val()); return success(); } }; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 88bdbdc38..cba9e8b6b 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -241,7 +241,8 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( auto argTy = arg.getType().cast(); auto argEltTy = argTy.getElementType(); auto i32Ty = IntegerType::get(argEltTy.getContext(), 32); - auto redOp = attributes.get("redOp").cast().getValue(); + auto redOp = + attributes.get("redOp").cast().getValue(); bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); auto retEltTy = withIndex ? i32Ty : argEltTy; auto retShape = argTy.getShape().vec(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a178470d9..3592c52d4 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -71,7 +71,7 @@ unsigned getElemsPerThread(Type type) { return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape()); } -SmallVector getThreadsPerWarp(Attribute layout) { +SmallVector getThreadsPerWarp(const Attribute &layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getThreadsPerWarp().begin(), blockedLayout.getThreadsPerWarp().end()); @@ -86,7 +86,7 @@ SmallVector getThreadsPerWarp(Attribute layout) { return {}; } -SmallVector getWarpsPerCTA(Attribute layout) { +SmallVector getWarpsPerCTA(const Attribute &layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getWarpsPerCTA().begin(), blockedLayout.getWarpsPerCTA().end()); @@ -99,7 +99,7 @@ SmallVector getWarpsPerCTA(Attribute layout) { return {}; } -SmallVector getSizePerThread(Attribute layout) { +SmallVector getSizePerThread(const Attribute &layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getSizePerThread().begin(), blockedLayout.getSizePerThread().end()); @@ -659,6 +659,15 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer, printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); } +DenseSet +InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; +} + //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 4ef695276..b2292d901 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) { tensorType.getEncoding()); } +#define int_attr(num) builder.getI64IntegerAttr(num) + namespace { class LoopPipeliner { - /// cache forOp we are working on + /// Cache forOp we are working on scf::ForOp forOp; - /// cache YieldOp for this forOp + /// Cache YieldOp for this forOp scf::YieldOp yieldOp; - /// loads to be pipelined + /// Loads to be pipelined SetVector loads; - /// the value that each load will be mapped to (after layout conversion) + /// The value that each load will be mapped to (after layout conversion) DenseMap loadsMapping; /// load => buffer DenseMap loadsBuffer; @@ -51,7 +53,7 @@ class LoopPipeliner { /// Value loopIterIdx; - /// comments on numStages: + /// Comments on numStages: /// [0, numStages-1) are in the prologue /// numStages-1 is appended after the loop body int numStages; @@ -61,6 +63,7 @@ class LoopPipeliner { /// Block arguments that loads depend on DenseSet depArgs; + /// Operations (inside the loop body) that loads depend on DenseSet depOps; @@ -71,7 +74,7 @@ class LoopPipeliner { Value lookupOrDefault(Value origin, int stage); - /// returns a empty buffer of size + /// Returns a empty buffer of size ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder); public: @@ -84,7 +87,7 @@ public: /// Collect loads to pipeline. Return success if we can pipeline this loop LogicalResult initialize(); - /// emit pipelined loads (before loop body) + /// Emit pipelined loads (before loop body) void emitPrologue(); /// emit pipelined loads (after loop body) @@ -134,7 +137,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) { - // allocate a buffer for each pipelined tensor + // Allocate a buffer for each pipelined tensor // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16> Value convertLayout = loadsMapping[op->getResult(0)]; if (auto tensorType = convertLayout.getType().dyn_cast()) { @@ -215,9 +218,9 @@ LogicalResult LoopPipeliner::initialize() { loads.insert(loadOp); } - // we have some loads to pipeline + // We have some loads to pipeline if (!loads.empty()) { - // update depArgs & depOps + // Update depArgs & depOps for (Value loadOp : loads) { for (Value dep : loadDeps[loadOp]) { // TODO: we should record the stage that the value is depended on @@ -244,23 +247,20 @@ void LoopPipeliner::emitPrologue() { setValueMapping(arg, operand.get(), 0); } - // helper to construct int attribute - auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); }; - // prologue from [0, numStage-1) Value iv = forOp.getLowerBound(); pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); for (int stage = 0; stage < numStages - 1; ++stage) { - // special handling for induction variable as the increment is implicit + // Special handling for induction variable as the increment is implicit if (stage != 0) iv = builder.create(iv.getLoc(), iv, forOp.getStep()); setValueMapping(forOp.getInductionVar(), iv, stage); - // special handling for loop condition as there is no condition in ForOp + // Special handling for loop condition as there is no condition in ForOp Value loopCond = builder.create( iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); - // rematerialize peeled values + // Rematerialize peeled values SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) @@ -314,7 +314,7 @@ void LoopPipeliner::emitPrologue() { } } - // update mapping of results + // Update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { Value originalResult = op->getResult(dstIdx); // copy_async will update the value of its only use @@ -350,13 +350,14 @@ void LoopPipeliner::emitPrologue() { loadsBufferType[loadOp].getEncoding()); Value extractSlice = builder.create( loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1], - SmallVector{intAttr(0), intAttr(0), intAttr(0)}, - SmallVector{intAttr(1), intAttr(sliceType.getShape()[0]), - intAttr(sliceType.getShape()[1])}, - SmallVector{intAttr(1), intAttr(1), intAttr(1)}); + SmallVector{int_attr(0), int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), + int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); loadsExtract[loadOp] = extractSlice; } - // bump up loopIterIdx, this is used for getting the correct slice for the + // Bump up loopIterIdx, this is used for getting the correct slice for the // *next* iteration loopIterIdx = builder.create( loopIterIdx.getLoc(), loopIterIdx, @@ -365,9 +366,6 @@ void LoopPipeliner::emitPrologue() { void LoopPipeliner::emitEpilogue() { // If there's any outstanding async copies, we need to wait for them. - // TODO(Keren): We may want to completely avoid the async copies in the last - // few iterations by setting is_masked attribute to true. We don't want to use - // the mask operand because it's a tensor but not a scalar. OpBuilder builder(forOp); OpBuilder::InsertionGuard g(builder); builder.setInsertionPointAfter(forOp); @@ -376,9 +374,8 @@ void LoopPipeliner::emitEpilogue() { scf::ForOp LoopPipeliner::createNewForOp() { OpBuilder builder(forOp); - auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; - // order of new args: + // Order of new args: // (original args), // (insertSliceAsync buffer at stage numStages - 1) for each load // (extracted tensor) for each load @@ -465,7 +462,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]); ++argIdx; } - // special handling for iv & loop condition + // Special handling for iv & loop condition Value nextIV = builder.create( newForOp.getInductionVar().getLoc(), newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep()); @@ -473,7 +470,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); - // slice index + // Slice index SmallVector nextBuffers; SmallVector extractSlices; @@ -490,7 +487,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { for (Operation *op : orderedDeps) { Operation *nextOp = nullptr; - // update loading mask + // Update loading mask if (loads.contains(op->getResult(0))) { auto loadOp = llvm::cast(op); Value mask = loadOp.mask(); @@ -500,7 +497,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { mask.getLoc(), mask.getType(), nextLoopCond); newMask = builder.create( mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask)); - // if mask is defined outside the loop, don't update the map more than + // If mask is defined outside the loop, don't update the map more than // once if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) nextMapping.map(mask, newMask); @@ -522,18 +519,19 @@ scf::ForOp LoopPipeliner::createNewForOp() { loadsBufferType[loadOp].getEncoding()); nextOp = builder.create( op->getLoc(), sliceType, insertAsyncOp, - SmallVector{extractSliceIndex, intAttr(0), intAttr(0)}, - SmallVector{intAttr(1), - intAttr(sliceType.getShape()[0]), - intAttr(sliceType.getShape()[1])}, - SmallVector{intAttr(1), intAttr(1), intAttr(1)}); + SmallVector{extractSliceIndex, int_attr(0), + int_attr(0)}, + SmallVector{int_attr(1), + int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); extractSlices.push_back(nextOp->getResult(0)); } else nextOp = builder.clone(*op, nextMapping); - // update mapping of results + // Update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - // if this is a loop-carried value, update the mapping for yield + // If this is a loop-carried value, update the mapping for yield auto originYield = cast(forOp.getBody()->getTerminator()); for (OpOperand &operand : originYield->getOpOperands()) { if (operand.get() == op->getResult(dstIdx)) { @@ -583,7 +581,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { it->getDefiningOp()->moveAfter(asyncWait); } - // bump iteration count + // Bump iteration count pipelineIterIdx = builder.create( nextIV.getLoc(), pipelineIterIdx, builder.create(nextIV.getLoc(), 1, 32)); diff --git a/python/src/triton.cc b/python/src/triton.cc index 141f16006..95fa120bc 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -341,7 +341,7 @@ void init_triton_ir(py::module &&m) { return funcs[0]; }); - m.def("make_attr", + m.def("make_attr", [](const std::vector &values, mlir::MLIRContext &context) { return mlir::DenseIntElementsAttr::get( mlir::RankedTensorType::get( @@ -1113,7 +1113,8 @@ void init_triton_ir(py::module &&m) { mlir::Value &val) -> mlir::Value { auto loc = self.getUnknownLoc(); mlir::Type dstType; - if (auto srcTensorType = ptr.getType().dyn_cast()) { + if (auto srcTensorType = + ptr.getType().dyn_cast()) { mlir::Type dstElemType = srcTensorType.getElementType() .cast() .getPointeeType(); diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 0edd53b30..e37c8490c 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -172,8 +172,9 @@ def get_proper_err(a, b, golden): [128, 64, 128, 4, 128, 64, 128, False, False], [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue # K-Forloop - [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding - [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k + #[16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads + [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding + [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k [64, 32, 128, 4, 64, 32, 64, False, False], [128, 16, 128, 4, 128, 16, 32, False, False], [32, 16, 128, 4, 32, 16, 32, False, False], diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e7f9a4b90..8ec45a385 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -387,6 +387,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> +#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> +#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_fallback + func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3> + %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2> + %cst_scalar = arith.constant 64 : i32 + %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2> + %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2> + %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3> + %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL> + %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL> + %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> + %a_init = tt.splat %arg0 : (!tt.ptr) -> tensor<16x64x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL> + %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A> + %index = arith.constant 1 : i32 + + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #AL> -> tensor<2x16x64xf16, #A> + return + } +} + +// ----- + #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>