diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h
index 6026e2648..f9cb2e66b 100644
--- a/include/triton/Analysis/AxisInfo.h
+++ b/include/triton/Analysis/AxisInfo.h
@@ -131,6 +131,12 @@ public:
   ChangeResult visitOperation(Operation *op,
                               ArrayRef<LatticeElement<AxisInfo> *> operands) override;
+
+  unsigned getPtrVectorSize(Value ptr);
+
+  unsigned getPtrAlignment(Value ptr);
+
+  unsigned getMaskAlignment(Value mask);
 };
 
 } // namespace mlir
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index c4f169935..8c24a5777 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -25,11 +25,11 @@ namespace gpu {
 
 unsigned getElemsPerThread(Type type);
 
-SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
+SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
 
-SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
+SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
 
-SmallVector<unsigned> getSizePerThread(Attribute layout);
+SmallVector<unsigned> getSizePerThread(const Attribute &layout);
 
 SmallVector<unsigned> getContigPerThread(Attribute layout);
 
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index e5b1da097..1fe76624d 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -151,6 +151,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
   //   attr-dict `:` type($src) `->` type($dst)
   //}];
 
+  let extraClassDeclaration = [{
+    static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
+  }];
+
   // The custom parser could be replaced with oilist in LLVM-16
   let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
 
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 58e3efa3c..42394c3a3 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -276,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   return result;
 }
 
+unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
+  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+  if (!tensorTy)
+    return 1;
+  auto layout = tensorTy.getEncoding();
+  auto shape = tensorTy.getShape();
+
+  // Here order should be ordered by contiguous first, so the first element
+  // should have the largest contiguity.
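+  // For example: with order = [1, 0], an alignment of 8 elements, and
+  // sizePerThread = 4 along the most contiguous dimension, the vector size
+  // computed below is min(8, 4) = 4, further clamped by the tensor shape
+  // along that dimension.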
+  auto order = triton::gpu::getOrder(layout);
+  unsigned align = getPtrAlignment(ptr);
+
+  unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
+  unsigned vec = std::min(align, contigPerThread);
+  vec = std::min(shape[order[0]], vec);
+
+  return vec;
+}
+
+unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
+  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+  if (!tensorTy)
+    return 1;
+  auto axisInfo = lookupLatticeElement(ptr)->getValue();
+  auto layout = tensorTy.getEncoding();
+  auto order = triton::gpu::getOrder(layout);
+  unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
+  unsigned maxContig = axisInfo.getContiguity(order[0]);
+  unsigned alignment = std::min(maxMultiple, maxContig);
+  return alignment;
+}
+
+unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
+  auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
+  if (!tensorTy)
+    return 1;
+  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto maskAxis = lookupLatticeElement(mask)->getValue();
+  auto alignment = std::max(maskAxis.getConstancy(maskOrder[0]), 1);
+  return alignment;
+}
+
 } // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index acf33a905..349ccef01 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -822,7 +822,7 @@ struct ArithConstantSplatOpConversion
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
   explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
-      : AxisAnalysisPass(axisAnalysisPass) {}
+      : axisAnalysisPass(axisAnalysisPass) {}
 
   // Get corresponding LLVM element values of \param value.
   static SmallVector<Value> getLLVMElems(Value value, Value llValue,
@@ -838,51 +838,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
   }
 
   unsigned getVectorSize(Value ptr) const {
-    auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
-    if (!tensorTy)
-      return 1;
-    auto layout = tensorTy.getEncoding();
-    auto shape = tensorTy.getShape();
-
-    auto axisInfo = getAxisInfo(ptr);
-    // Here order should be ordered by contiguous first, so the first element
-    // should have the largest contiguous.
-    auto order = getOrder(layout);
-    unsigned align = getAlignment(ptr, layout);
-
-    unsigned contigPerThread = getSizePerThread(layout)[order[0]];
-    unsigned vec = std::min(align, contigPerThread);
-    vec = std::min(shape[order[0]], vec);
-
-    return vec;
-  }
-
-  unsigned getAlignment(Value val, const Attribute &layout) const {
-    auto axisInfo = getAxisInfo(val);
-    auto order = getOrder(layout);
-    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
-    unsigned maxContig = axisInfo->getContiguity(order[0]);
-    unsigned alignment = std::min(maxMultiple, maxContig);
-    return alignment;
+    return axisAnalysisPass.getPtrVectorSize(ptr);
   }
 
   unsigned getMaskAlignment(Value mask) const {
-    auto tensorTy = mask.getType().cast<RankedTensorType>();
-    auto maskOrder = getOrder(tensorTy.getEncoding());
-    auto maskAxis = getAxisInfo(mask);
-    return std::max(maskAxis->getConstancy(maskOrder[0]), 1);
-  }
-
-  llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
-    if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
-      return it->getValue();
-    }
-
-    return llvm::Optional<AxisInfo>{};
+    return axisAnalysisPass.getMaskAlignment(mask);
   }
 
 protected:
-  AxisInfoAnalysis &AxisAnalysisPass;
+  AxisInfoAnalysis &axisAnalysisPass;
 };
 
 struct LoadOpConversion
@@ -4604,30 +4568,68 @@ private:
     });
   }
 
-  void decomposeInsertSliceAsyncOp(ModuleOp mod,
-                                   TritonGPUToLLVMTypeConverter &converter) {
-    // cp.async is supported in Ampere and later
-    if (computeCapability >= 80)
-      return;
-
+  void decomposeInsertSliceAsyncOp(ModuleOp mod) {
+    AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
+    axisInfoAnalysis.run(mod);
+    // TODO(Keren): This is a hacky knob that may cause performance regression
+    // when decomposition has been performed. We should remove this knob once we
+    // have thorough analysis on async wait. Currently, we decompose
+    // `insert_slice_async` into `load` and `insert_slice` without knowing which
+    // `async_wait` is responsible for the `insert_slice_async`. To guarantee
+    // correctness, we blindly set the `async_wait` to wait for all async ops.
+    //
+    // There are two options to improve this:
+    // 1. We can perform a dataflow analysis to find the `async_wait` that is
+    // responsible for the `insert_slice_async` in the backend.
+    // 2. We can modify the pipeline to perform the decomposition before the
+    // `async_wait` is inserted. However, it is also risky because we don't know
+    // the correct vectorized shape yet in the pipeline pass. Making the
+    // pipeline pass aware of the vectorization could introduce additional
+    // dependencies on the AxisInfoAnalysis and the Coalesce analysis.
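+    //
+    // Note: `decomposed` below records whether any `insert_slice_async` was
+    // rewritten by the first walk; if so, the remaining `async_wait` ops are
+    // rewritten to wait for all outstanding async copies (see the second walk
+    // over AsyncWaitOp further down).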
+    bool decomposed = false;
     // insert_slice_async %src, %dst, %idx, %mask, %other
     // =>
     // %tmp = load %src, %mask, %other
     // %res = insert_slice %tmp into %dst[%idx]
     mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
       OpBuilder builder(insertSliceAsyncOp);
-      // load
-      auto srcTy = insertSliceAsyncOp.src().getType().cast<RankedTensorType>();
-      auto dstTy = insertSliceAsyncOp.getType().cast<RankedTensorType>();
+
+      // Get the vectorized load size
+      auto src = insertSliceAsyncOp.src();
+      auto dst = insertSliceAsyncOp.dst();
+      auto srcTy = src.getType().cast<RankedTensorType>();
+      auto dstTy = dst.getType().cast<RankedTensorType>();
       auto srcBlocked =
           srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
-      auto elemTy = converter.convertType(dstTy.getElementType());
-      auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked);
+      auto resSharedLayout =
+          dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
+      auto resElemTy = dstTy.getElementType();
+      unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
+      unsigned outVec = resSharedLayout.getVec();
+      unsigned minVec = std::min(outVec, inVec);
+      auto maxBitWidth =
+          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
+      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
+      auto bitWidth = std::min(maxBitWidth, vecBitWidth);
+      auto byteWidth = bitWidth / 8;
+
+      // If the load byte width is not eligible or the current compute
+      // capability does not support async copy, then we decompose
+      if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
+              computeCapability)
+              .contains(byteWidth) &&
+          computeCapability >= 80)
+        return;
+
+      // load
+      auto tmpTy =
+          RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
       auto loadOp = builder.create<triton::LoadOp>(
           insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
           insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
           insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
           insertSliceAsyncOp.isVolatile());
+
       // insert_slice
       auto axis = insertSliceAsyncOp.axis();
       auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
@@ -4645,10 +4647,20 @@ private:
       // Replace
       insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
       insertSliceAsyncOp.erase();
+      decomposed = true;
     });
 
+    // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      asyncWaitOp.erase();
+      if (computeCapability < 80) {
+        asyncWaitOp.erase();
+      } else if (decomposed) {
+        OpBuilder builder(asyncWaitOp);
+        // Wait for all previous async ops
+        auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
+            asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
+        asyncWaitOp.erase();
+      }
     });
   }
 
@@ -4671,7 +4683,7 @@ public:
   // step 1: Decompose unoptimized layout conversions to use shared memory
   // step 2: Decompose insert_slice_async to use load + insert_slice for
-  //         pre-Ampere architectures
+  //         pre-Ampere architectures or unsupported vectorized load sizes
   // step 3: Allocate shared memories and insert barriers
   // step 4: Convert SCF to CFG
   // step 5: Convert FuncOp to LLVMFuncOp via partial conversion
@@ -4681,10 +4693,9 @@ public:
   // separation between 1/4 is that, step 3 is out of the scope of Dialect
   // Conversion, thus we need to make sure the smem is not revised during the
   // conversion of step 4.
-    decomposeBlockedToDotOperand(mod);
-    decomposeInsertSliceAsyncOp(mod, typeConverter);
+    decomposeInsertSliceAsyncOp(mod);
 
     Allocation allocation(mod);
     MembarAnalysis membar(&allocation);
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 637acab64..b9fd481a6 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -286,8 +286,8 @@ struct TritonAtomicCASPattern
   matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
-        op, typeConverter->convertType(op.getType()),
-        adaptor.ptr(), adaptor.cmp(), adaptor.val());
+        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
+        adaptor.cmp(), adaptor.val());
     return success();
   }
 };
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 88bdbdc38..cba9e8b6b 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -241,7 +241,8 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   auto argTy = arg.getType().cast<RankedTensorType>();
   auto argEltTy = argTy.getElementType();
   auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
-  auto redOp = attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
+  auto redOp =
+      attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
   bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
   auto retEltTy = withIndex ? i32Ty : argEltTy;
   auto retShape = argTy.getShape().vec();
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index a178470d9..3592c52d4 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -71,7 +71,7 @@ unsigned getElemsPerThread(Type type) {
   return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
 }
 
-SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
+SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
                                  blockedLayout.getThreadsPerWarp().end());
@@ -86,7 +86,7 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
   return {};
 }
 
-SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
+SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
                                  blockedLayout.getWarpsPerCTA().end());
@@ -99,7 +99,7 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
   return {};
 }
 
-SmallVector<unsigned> getSizePerThread(Attribute layout) {
+SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
@@ -659,6 +659,15 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
   printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
 }
 
+DenseSet<unsigned>
+InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
+  DenseSet<unsigned> validLoadBytes;
+  if (computeCapability >= 80) {
+    validLoadBytes = {4, 8, 16};
+  }
+  return validLoadBytes;
+}
+
 //===----------------------------------------------------------------------===//
 // ASM Interface (i.e.: alias)
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 4ef695276..b2292d901 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
                                  tensorType.getEncoding());
 }
 
+#define int_attr(num) builder.getI64IntegerAttr(num)
+
 namespace {
 class LoopPipeliner {
-  /// cache forOp we are working on
+  /// Cache forOp we are working on
   scf::ForOp forOp;
-  /// cache YieldOp for this forOp
+  /// Cache YieldOp for this forOp
  scf::YieldOp yieldOp;
-  /// loads to be pipelined
+  /// Loads to be pipelined
   SetVector<Value> loads;
-  /// the value that each load will be mapped to (after layout conversion)
+  /// The value that each load will be mapped to (after layout conversion)
   DenseMap<Value, Value> loadsMapping;
   /// load => buffer
   DenseMap<Value, Value> loadsBuffer;
@@ -51,7 +53,7 @@ class LoopPipeliner {
   ///
   Value loopIterIdx;
 
-  /// comments on numStages:
+  /// Comments on numStages:
   ///   [0, numStages-1) are in the prologue
   ///   numStages-1 is appended after the loop body
   int numStages;
@@ -61,6 +63,7 @@ class LoopPipeliner {
 
   /// Block arguments that loads depend on
   DenseSet<BlockArgument> depArgs;
+
   /// Operations (inside the loop body) that loads depend on
   DenseSet<Operation *> depOps;
 
@@ -71,7 +74,7 @@ class LoopPipeliner {
 
   Value lookupOrDefault(Value origin, int stage);
 
-  /// returns a empty buffer of size
+  /// Returns an empty buffer of size
   ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
 
 public:
@@ -84,7 +87,7 @@ public:
   /// Collect loads to pipeline. Return success if we can pipeline this loop
   LogicalResult initialize();
 
-  /// emit pipelined loads (before loop body)
+  /// Emit pipelined loads (before loop body)
   void emitPrologue();
 
   /// emit pipelined loads (after loop body)
@@ -134,7 +137,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
 
 ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
                                                       OpBuilder &builder) {
-  // allocate a buffer for each pipelined tensor
+  // Allocate a buffer for each pipelined tensor
   // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
   Value convertLayout = loadsMapping[op->getResult(0)];
   if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
@@ -215,9 +218,9 @@ LogicalResult LoopPipeliner::initialize() {
       loads.insert(loadOp);
   }
 
-  // we have some loads to pipeline
+  // We have some loads to pipeline
   if (!loads.empty()) {
-    // update depArgs & depOps
+    // Update depArgs & depOps
     for (Value loadOp : loads) {
      for (Value dep : loadDeps[loadOp]) {
         // TODO: we should record the stage that the value is depended on
@@ -244,23 +247,20 @@ void LoopPipeliner::emitPrologue() {
     setValueMapping(arg, operand.get(), 0);
   }
 
-  // helper to construct int attribute
-  auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
-
   // prologue from [0, numStage-1)
   Value iv = forOp.getLowerBound();
   pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (int stage = 0; stage < numStages - 1; ++stage) {
-    // special handling for induction variable as the increment is implicit
+    // Special handling for induction variable as the increment is implicit
     if (stage != 0)
       iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
     setValueMapping(forOp.getInductionVar(), iv, stage);
 
-    // special handling for loop condition as there is no condition in ForOp
+    // Special handling for loop condition as there is no condition in ForOp
     Value loopCond = builder.create<arith::CmpIOp>(
        iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
 
-    // rematerialize peeled values
+    // Rematerialize peeled values
     SmallVector<Operation *> orderedDeps;
     for (Operation &op : forOp.getLoopBody().front()) {
       if (depOps.contains(&op))
@@ -314,7 +314,7 @@ void LoopPipeliner::emitPrologue() {
       }
     }
 
-    // update mapping of results
+    // Update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       Value originalResult = op->getResult(dstIdx);
       // copy_async will update the value of its only use
@@ -350,13 +350,14 @@ void LoopPipeliner::emitPrologue() {
                                     loadsBufferType[loadOp].getEncoding());
     Value extractSlice = builder.create<tensor::ExtractSliceOp>(
         loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
-        SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
-        SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]), intAttr(sliceType.getShape()[1])},
-        SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+        SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
+        SmallVector<OpFoldResult>{int_attr(1),
+                                  int_attr(sliceType.getShape()[0]),
+                                  int_attr(sliceType.getShape()[1])},
+        SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
     loadsExtract[loadOp] = extractSlice;
   }
-  // bump up loopIterIdx, this is used for getting the correct slice for the
+  // Bump up loopIterIdx, this is used for getting the correct slice for the
   // *next* iteration
   loopIterIdx = builder.create<arith::AddIOp>(
       loopIterIdx.getLoc(), loopIterIdx,
@@ -365,9 +366,6 @@ void LoopPipeliner::emitEpilogue() {
   // If there's any outstanding async copies, we need to wait for them.
-  // TODO(Keren): We may want to completely avoid the async copies in the last
-  // few iterations by setting is_masked attribute to true. We don't want to use
-  // the mask operand because it's a tensor but not a scalar.
   OpBuilder builder(forOp);
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPointAfter(forOp);
@@ -376,9 +374,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
-  auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
-  // order of new args:
+  // Order of new args:
   //   (original args),
   //   (insertSliceAsync buffer at stage numStages - 1) for each load
   //   (extracted tensor) for each load
@@ -465,7 +462,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
     ++argIdx;
   }
-  // special handling for iv & loop condition
+  // Special handling for iv & loop condition
   Value nextIV = builder.create<arith::AddIOp>(
       newForOp.getInductionVar().getLoc(),
       newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
@@ -473,7 +470,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   Value nextLoopCond =
       builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                     nextIV, newForOp.getUpperBound());
 
-  // slice index
+  // Slice index
   SmallVector<Value> nextBuffers;
   SmallVector<Value> extractSlices;
@@ -490,7 +487,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
-    // update loading mask
+    // Update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
@@ -500,7 +497,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
            mask.getLoc(), mask.getType(), nextLoopCond);
         newMask = builder.create<arith::AndIOp>(
            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
-        // if mask is defined outside the loop, don't update the map more than
+        // If mask is defined outside the loop, don't update the map more than
         // once
         if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
           nextMapping.map(mask, newMask);
@@ -522,18 +519,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
                                  loadsBufferType[loadOp].getEncoding());
       nextOp = builder.create<tensor::ExtractSliceOp>(
           op->getLoc(), sliceType, insertAsyncOp,
-          SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
-          SmallVector<OpFoldResult>{intAttr(1),
-                                    intAttr(sliceType.getShape()[0]),
-                                    intAttr(sliceType.getShape()[1])},
-          SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+          SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
+                                    int_attr(0)},
+          SmallVector<OpFoldResult>{int_attr(1),
+                                    int_attr(sliceType.getShape()[0]),
+                                    int_attr(sliceType.getShape()[1])},
+          SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
       extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
-    // update mapping of results
+    // Update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
-      // if this is a loop-carried value, update the mapping for yield
+      // If this is a loop-carried value, update the mapping for yield
       auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
       for (OpOperand &operand : originYield->getOpOperands()) {
         if (operand.get() == op->getResult(dstIdx)) {
@@ -583,7 +581,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       it->getDefiningOp()->moveAfter(asyncWait);
   }
 
-  // bump iteration count
+  // Bump iteration count
   pipelineIterIdx = builder.create<arith::AddIOp>(
       nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 141f16006..95fa120bc 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -341,7 +341,7 @@ void init_triton_ir(py::module &&m) {
        return funcs[0];
       });
 
-  m.def("make_attr", 
+  m.def("make_attr",
        [](const std::vector<int> &values, mlir::MLIRContext &context) {
          return mlir::DenseIntElementsAttr::get(
              mlir::RankedTensorType::get(
@@ -1113,7 +1113,8 @@ void init_triton_ir(py::module &&m) {
             mlir::Value &val) -> mlir::Value {
            auto loc = self.getUnknownLoc();
            mlir::Type dstType;
-           if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
+           if (auto srcTensorType =
+                   ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
              mlir::Type dstElemType = srcTensorType.getElementType()
                                           .cast<mlir::triton::PointerType>()
                                           .getPointeeType();
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 0edd53b30..e37c8490c 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -172,8 +172,9 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 128, False, False],
     [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
     # K-Forloop
-    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
-    [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k
+    #[16, 16, 64, 4, 8, 8, 8, False, False],  # Wrap threads
+    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
+    [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k
     [64, 32, 128, 4, 64, 32, 64, False, False],
     [128, 16, 128, 4, 128, 16, 32, False, False],
     [32, 16, 128, 4, 32, 16, 32, False, False],
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index e7f9a4b90..8ec45a385 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -387,6 +387,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
+#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_fallback
+  func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
+    %cst_scalar = arith.constant 64 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr, 3>
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr, 3>
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
+    return
+  }
+}
+
+// -----
+
 #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
 #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>