From 6c5f646f4e5ee62ff34139df56d7dedbb65cdbb2 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Sat, 19 Nov 2022 19:57:16 -0800
Subject: [PATCH] [WIP][Triton-MLIR] Prefetch pass fixup (#873)

There is a (potential) problem with directly adopting `tensor.extract_slice`. Long story short, `tensor.extract_slice` is not aware of swizzling. Consider the following shared memory tensor and its first three slices, where each slice includes two tiles (the loading unit of LDGSTS) of elements. Currently, the tiles haven't been swizzled yet, so slicing seems to work.

[figure: shared memory tensor before swizzling, with its first three slices]

However, now consider the following figure, which shows the layout after applying swizzling to the first figure.

[figure: the same tensor after swizzling]

Note that in phase 2, all tiles have been swizzled out of their original slices. This implies that if we use the tile index computed after slicing, we can no longer locate the correct tiles. For example, T3 was in slice 1 but got swapped to slice 0 after swizzling.

Here's a more detailed explanation. In the current `triton-mlir` branch, we only compute the relative offset of each tile. So T3's index in slice 1 is *1*, and it will be swizzled using *1* and *phase id*. Whereas the correct index of T3 should be *3*, which is its offset relative to the beginning of the shared memory tensor being swizzled, and T3 should be swizzled using *3* and *phase id*. (Standalone sketches of these computations are appended after the diff.)

This PR proposes a hacky solution for this problem. We restore the "correct" offset of each tile by **assuming that slicing on a specific dim happens at most once on the output of insert_slice_async**. I admit it's risky and fragile.

The other possible solution is to adopt cutlass's swizzling logic, which limits the indices being swizzled to a "bounding box" that matches the tile the mma instruction executes on. For example, in the following tensor layout, each 4x4 submatrix is a minimum swizzling unit, and the entire tensor represents the layout of operand A in `mma.16816`.

[figure: operand A layout for `mma.16816`, where each 4x4 submatrix is a minimum swizzling unit]

Co-authored-by: Phil Tillet
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 146 ++++++++++++------
 lib/Dialect/Triton/IR/Ops.cpp                 |   6 +-
 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp |   6 +-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  |  22 ++-
 lib/Dialect/TritonGPU/Transforms/Prefetch.cpp |  21 +--
 python/triton/compiler.py                     |   9 +-
 test/Conversion/tritongpu_to_llvm.mlir        |   6 +
 7 files changed, 146 insertions(+), 70 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index b1d37a0a5..00129325d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -99,6 +99,7 @@ void llPrintf(StringRef msg, ValueRange args, #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) +#define sub(...) rewriter.create(loc, __VA_ARGS__) #define fadd(...) rewriter.create(loc, __VA_ARGS__) #define mul(...) rewriter.create(loc, __VA_ARGS__) #define smax(...) rewriter.create(loc, __VA_ARGS__) @@ -441,25 +442,48 @@ struct SharedMemoryObject { // if we want to support more optimizations. SmallVector strides; // i32 int. The strides of the shared memory object. + SmallVector offsets; // i32 int. The offsets of the shared memory + // objects from the originally allocated object.
- SharedMemoryObject(Value base, ArrayRef strides) - : base(base), strides(strides.begin(), strides.end()) {} + SharedMemoryObject(Value base, ArrayRef strides, + ArrayRef offsets) + : base(base), strides(strides.begin(), strides.end()), + offsets(offsets.begin(), offsets.end()) {} + SharedMemoryObject(Value base, ArrayRef shape, + ArrayRef order, Location loc, + ConversionPatternRewriter &rewriter) + : base(base) { + auto rank = shape.size(); + auto stride = 1; + strides.resize(rank); + for (auto idx : order) { + strides[idx] = i32_val(stride); + offsets.emplace_back(i32_val(0)); + stride *= shape[idx]; + } + } + + // XXX(Keren): a special allocator for 3d tensors. It's a workaround for + // now since we don't have a correct way to encoding 3d tensors in the + // pipeline pass. SharedMemoryObject(Value base, ArrayRef shape, Location loc, ConversionPatternRewriter &rewriter) : base(base) { auto stride = 1; for (auto dim : llvm::reverse(shape)) { - this->strides.emplace_back(i32_val(stride)); + strides.emplace_back(i32_val(stride)); + offsets.emplace_back(i32_val(0)); stride *= dim; } - this->strides = llvm::to_vector<4>(llvm::reverse(this->strides)); + strides = llvm::to_vector<4>(llvm::reverse(strides)); } SmallVector getElems() const { SmallVector elems; elems.push_back(base); elems.append(strides.begin(), strides.end()); + elems.append(offsets.begin(), offsets.end()); return elems; } @@ -467,8 +491,22 @@ struct SharedMemoryObject { SmallVector types; types.push_back(base.getType()); types.append(strides.size(), IntegerType::get(base.getContext(), 32)); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); return types; } + + Value getCSwizzleOffset(int order) const { + assert(order >= 0 && order < strides.size()); + return offsets[order]; + } + + Value getBaseBeforeSwizzle(int order, Location loc, + ConversionPatternRewriter &rewriter) const { + Value cSwizzleOffset = getCSwizzleOffset(order); + Value offset = sub(i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return gep(type, base, offset); + } }; struct ConvertTritonGPUOpToLLVMPatternBase { @@ -493,8 +531,11 @@ struct ConvertTritonGPUOpToLLVMPatternBase { getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { auto elems = getElementsFromStruct(loc, llvmStruct, rewriter); - return SharedMemoryObject(/*base=*/elems[0], - /*strides=*/{elems.begin() + 1, elems.end()}); + auto rank = (elems.size() - 1) / 2; + return SharedMemoryObject( + /*base=*/elems[0], + /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, + /*offsets=*/{elems.begin() + 1 + rank, elems.end()}); } static Value @@ -2238,31 +2279,34 @@ struct ExtractSliceOpConversion // Triton support either static and dynamic offsets auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter); + SmallVector opOffsetVals; SmallVector offsetVals; auto mixedOffsets = op.getMixedOffsets(); for (auto i = 0; i < mixedOffsets.size(); ++i) { if (op.isDynamicOffset(i)) - offsetVals.emplace_back(adaptor.offsets()[i]); + opOffsetVals.emplace_back(adaptor.offsets()[i]); else - offsetVals.emplace_back(i32_val(op.getStaticOffset(i))); + opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i))); + offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i])); } // Compute the offset based on the original strides of the shared memory // object - auto offset = dot(rewriter, loc, offsetVals, smemObj.strides); + auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides); // newShape = 
rank_reduce(shape) // Triton only supports static tensor sizes SmallVector strideVals; - auto staticSizes = op.static_sizes(); for (auto i = 0; i < op.static_sizes().size(); ++i) { - if (op.getStaticSize(i) != 1) { + if (op.getStaticSize(i) == 1) { + offsetVals.erase(offsetVals.begin() + i); + } else { strideVals.emplace_back(smemObj.strides[i]); } } auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); auto resTy = op.getType().dyn_cast(); - smemObj = - SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals); + smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), + strideVals, offsetVals); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); @@ -3128,7 +3172,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( Value smemBase = getSharedMemoryBase(loc, rewriter, dst); auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); smemBase = bitcast(smemBase, elemPtrTy); - auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter); + auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); unsigned numWordsEachRep = product(wordsInEachRep); SmallVector wordVecs(numWordsEachRep); @@ -3228,17 +3272,16 @@ public: if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwpt] if not transposed, // otherwise [wptx1], and each warp will perform a mma. - numPtr = + numPtrs = tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]]; } else { - numPtr = tileShape[order[0]] / wpt / matShape[order[0]]; + numPtrs = tileShape[order[0]] / wpt / matShape[order[0]]; } - - numPtr = std::max(numPtr, 2); + numPtrs = std::max(numPtrs, 2); // Special rule for i8/u8, 4 ptrs for each matrix if (!canUseLdmatrix && elemBytes == 1) - numPtr *= 4; + numPtrs *= 4; int loadStrideInMat[2]; loadStrideInMat[kOrder] = @@ -3257,24 +3300,26 @@ public: // lane = thread % 32 // warpOff = (thread/32) % wpt(0) - llvm::SmallVector computeOffsets(Value warpOff, Value lane) { + llvm::SmallVector computeOffsets(Value warpOff, Value lane, + Value cSwizzleOffset) { if (canUseLdmatrix) - return computeLdmatrixMatOffs(warpOff, lane); + return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset); else if (elemBytes == 4 && needTrans) - return computeB32MatOffs(warpOff, lane); + return computeB32MatOffs(warpOff, lane, cSwizzleOffset); else if (elemBytes == 1 && needTrans) - return computeB8MatOffs(warpOff, lane); + return computeB8MatOffs(warpOff, lane, cSwizzleOffset); else llvm::report_fatal_error("Invalid smem load config"); return {}; } - int getNumPtr() const { return numPtr; } + int getNumPtrs() const { return numPtrs; } // Compute the offset to the matrix this thread(indexed by warpOff and lane) // mapped to. - SmallVector computeLdmatrixMatOffs(Value warpId, Value lane) { + SmallVector computeLdmatrixMatOffs(Value warpId, Value lane, + Value cSwizzleOffset) { // 4x4 matrices Value c = urem(lane, i32_val(8)); Value s = udiv(lane, i32_val(8)); // sub-warp-id @@ -3312,14 +3357,16 @@ public: // Physical offset (before swizzling) Value cMatOff = matOff[order[0]]; Value sMatOff = matOff[order[1]]; + Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); + cMatOff = add(cMatOff, cSwizzleMatOff); // row offset inside a matrix, each matrix has 8 rows. 
Value sOffInMat = c; - SmallVector offs(numPtr); + SmallVector offs(numPtrs); Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); - for (int i = 0; i < numPtr; ++i) { + for (int i = 0; i < numPtrs; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); cMatOffI = xor_(cMatOffI, phase); offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride)); @@ -3329,14 +3376,15 @@ public: } // Compute 32-bit matrix offsets. - SmallVector computeB32MatOffs(Value warpOff, Value lane) { + SmallVector computeB32MatOffs(Value warpOff, Value lane, + Value cSwizzleOffset) { assert(needTrans && "Only used in transpose mode."); // Load tf32 matrices with lds32 Value cOffInMat = udiv(lane, i32_val(4)); Value sOffInMat = urem(lane, i32_val(4)); Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); - SmallVector offs(numPtr); + SmallVector offs(numPtrs); for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; @@ -3348,10 +3396,13 @@ public: Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), mul(nkMatArr, i32_val(matArrStride))); + Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); + cMatOff = add(cMatOff, cSwizzleMatOff); + Value sMatOff = kMatArr; Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); // FIXME: (kOrder == 1?) is really dirty hack - for (int i = 0; i < numPtr / 2; ++i) { + for (int i = 0; i < numPtrs / 2; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2))); cMatOffI = xor_(cMatOffI, phase); @@ -3365,13 +3416,14 @@ public: } // compute 8-bit matrix offset. - SmallVector computeB8MatOffs(Value warpOff, Value lane) { + SmallVector computeB8MatOffs(Value warpOff, Value lane, + Value cSwizzleOffset) { assert(needTrans && "Only used in transpose mode."); Value cOffInMat = udiv(lane, i32_val(4)); Value sOffInMat = mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols - SmallVector offs(numPtr); + SmallVector offs(numPtrs); for (int mat = 0; mat < 4; ++mat) { int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; int nkMatArrInt = kOrder == 1 ? 
mat % 2 : mat / 2; @@ -3384,7 +3436,7 @@ public: mul(nkMatArr, i32_val(matArrStride))); Value sMatOff = kMatArr; - for (int loadx4Off = 0; loadx4Off < numPtr / 8; ++loadx4Off) { + for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) { for (int elemOff = 0; elemOff < 4; ++elemOff) { int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff; Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat * @@ -3587,7 +3639,7 @@ private: bool needTrans; bool canUseLdmatrix; - int numPtr; + int numPtrs; int pLoadStrideInMat; int sMatStride; @@ -4392,14 +4444,17 @@ private: wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); - SmallVector offs = loader.computeOffsets(warpId, lane); - const int numPtrs = loader.getNumPtr(); + Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); + SmallVector offs = + loader.computeOffsets(warpId, lane, cSwizzleOffset); + const int numPtrs = loader.getNumPtrs(); SmallVector ptrs(numPtrs); + Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})), - smemPtrTy); + ptrs[i] = + bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy); } auto [ha0, ha1, ha2, ha3] = loader.loadX4( @@ -4432,7 +4487,7 @@ private: // ... // (2,0), (2,1), (3,0), (3,1), # i=1, j=0 // (2,2), (2,3), (3,2), (3,3), # i=1, j=1 - // (2,4), (2,5), (2,4), (2,5), # i=1, j=2 + // (2,4), (2,5), (3,4), (3,5), # i=1, j=2 // ... // ] // i \in [0, n0) and j \in [0, n1) @@ -4811,15 +4866,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, Value DotOpMmaV1ConversionHelper::loadA( Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { - // smem - Value smem = smemObj.base; - auto strides = smemObj.strides; auto *ctx = rewriter.getContext(); auto tensorTy = tensor.getType().cast(); auto shape = tensorTy.getShape(); auto sharedLayout = tensorTy.getEncoding().cast(); auto order = sharedLayout.getOrder(); + Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); bool isARow = order[0] != 0; bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes @@ -4834,6 +4887,7 @@ Value DotOpMmaV1ConversionHelper::loadA( int vecA = sharedLayout.getVec(); + auto strides = smemObj.strides; Value strideAM = isARow ? strides[0] : i32_val(1); Value strideAK = isARow ? i32_val(1) : strides[1]; Value strideA0 = isARow ? strideAK : strideAM; @@ -4856,8 +4910,8 @@ Value DotOpMmaV1ConversionHelper::loadA( Value offA0 = isARow ? offsetAK : offsetAM; Value offA1 = isARow ? offsetAM : offsetAK; Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA)); + offA0 = add(offA0, cSwizzleOffset); SmallVector offA(numPtrA); - for (int i = 0; i < numPtrA; i++) { Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM))); offA0I = udiv(offA0I, i32_val(vecA)); @@ -4875,6 +4929,7 @@ Value DotOpMmaV1ConversionHelper::loadA( SmallVector ptrA(numPtrA); std::map, std::pair> has; + auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); for (int i = 0; i < numPtrA; i++) ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]); @@ -4971,6 +5026,8 @@ Value DotOpMmaV1ConversionHelper::loadB( Value offB0 = isBRow ? offsetBN : offsetBK; Value offB1 = isBRow ? 
offsetBK : offsetBN; Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB)); + Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); + offB0 = add(offB0, cSwizzleOffset); SmallVector offB(numPtrB); for (int i = 0; i < numPtrB; ++i) { Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4))); @@ -5480,7 +5537,8 @@ public: types.push_back(ptrType); // shape dims auto rank = type.getRank(); - for (auto i = 0; i < rank; i++) { + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { types.push_back(IntegerType::get(ctx, 32)); } return LLVM::LLVMStructType::getLiteral(ctx, types); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 5c85e6b86..503bc84f3 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -126,7 +126,7 @@ namespace triton { //-- FpToFpOp -- bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs, - ::mlir::TypeRange outputs) { + ::mlir::TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; auto srcEltType = inputs.front(); @@ -143,8 +143,8 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs, std::swap(srcEltType, dstEltType); if (!srcEltType.dyn_cast()) return false; - return dstEltType.isF16() || dstEltType.isBF16() || - dstEltType.isF32() || dstEltType.isF64(); + return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() || + dstEltType.isF64(); } //-- StoreOp -- diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index d1b62f75c..d2cb9c964 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -33,9 +33,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { SmallVector sizePerThread(rank, 1); PointerType ptrType = origType.getElementType().cast(); auto pointeeType = ptrType.getPointeeType(); - unsigned numBits = - pointeeType.isa() ? - 8 : pointeeType.getIntOrFloatBitWidth(); + unsigned numBits = pointeeType.isa() + ? 8 + : pointeeType.getIntOrFloatBitWidth(); unsigned maxMultiple = info.getDivisibility(order[0]); unsigned maxContig = info.getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 2a5628716..6133f9381 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -78,8 +78,6 @@ public: if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); - auto srcType = convert.getOperand().getType().cast(); - auto dstType = convert.getType().cast(); // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention // if (dstType.getEncoding().isa()) @@ -96,6 +94,9 @@ public: // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) auto alloc_tensor = dyn_cast(arg); if (alloc_tensor) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } rewriter.replaceOpWithNewOp( op, op->getResult(0).getType()); return mlir::success(); @@ -103,6 +104,9 @@ public: // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) auto insert_slice = dyn_cast(arg); if (insert_slice) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } auto newType = op->getResult(0).getType().cast(); // Ensure that the new insert_slice op is placed in the same place as the // old insert_slice op. 
Otherwise, the new insert_slice op may be placed @@ -121,6 +125,9 @@ public: // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) auto extract_slice = dyn_cast(arg); if (extract_slice) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } auto origType = extract_slice.source().getType().cast(); auto newType = RankedTensorType::get( origType.getShape(), origType.getElementType(), @@ -144,16 +151,15 @@ public: return mlir::success(); } - // cvt(type2, x) + // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (llvm::isa(arg)) { - auto argType = arg->getOperand(0).getType().cast(); if (arg->getOperand(0).getDefiningOp() && - !argType.getEncoding().isa() && - srcType.getEncoding().isa() && - !dstType.getEncoding().isa()) { - + !isSharedEncoding(arg->getOperand(0)) && + isSharedEncoding(convert.getOperand()) && + !isSharedEncoding(convert.getResult())) { return mlir::failure(); } + auto srcType = convert.getOperand().getType().cast(); auto srcShared = srcType.getEncoding().dyn_cast(); if (srcShared && srcShared.getVec() > 1) diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 92be1bf7e..c7b451822 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -27,6 +27,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BlockAndValueMapping.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -60,7 +61,7 @@ class Prefetcher { LogicalResult isForOpOperand(Value v); - Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch, + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, Attribute dotEncoding, OpBuilder &builder, llvm::Optional offsetK = llvm::None, llvm::Optional shapeK = llvm::None); @@ -79,7 +80,7 @@ public: scf::ForOp createNewForOp(); }; -Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch, +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, Attribute dotEncoding, OpBuilder &builder, llvm::Optional offsetK, llvm::Optional shapeK) { @@ -94,8 +95,8 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch, // k => (prefetchWidth, k - prefetchWidth) int64_t kIdx = opIdx == 0 ? 1 : 0; - offset[kIdx] = isPrefetch ? 0 : prefetchWidth; - shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth); + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? 
prefetchWidth : (shape[kIdx] - prefetchWidth); if (shapeK) shape[kIdx] = *shapeK; @@ -132,9 +133,9 @@ LogicalResult Prefetcher::initialize() { // returns source of cvt auto getPrefetchSrc = [](Value v) -> Value { - // TODO: Check if the layout of src is SharedEncodingAttr if (auto cvt = v.getDefiningOp()) - return cvt.src(); + if (isSharedEncoding(cvt.getOperand())) + return cvt.src(); return Value(); }; @@ -152,6 +153,10 @@ LogicalResult Prefetcher::initialize() { }; for (triton::DotOp dot : dotsInFor) { + auto kSize = dot.a().getType().cast().getShape()[1]; + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; Value aSmem = getPrefetchSrc(dot.a()); Value bSmem = getPrefetchSrc(dot.b()); if (aSmem && bSmem) { @@ -217,7 +222,7 @@ scf::ForOp Prefetcher::createNewForOp() { mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); for (Operation &op : forOp.getBody()->without_terminator()) { - Operation *newOp = nullptr; + Operation *newOp = builder.clone(op, mapping); auto dot = dyn_cast(&op); if (dots.contains(dot)) { Attribute dotEncoding = @@ -252,8 +257,6 @@ scf::ForOp Prefetcher::createNewForOp() { kOff += kShape; kRem -= kShape; } - } else { - newOp = builder.clone(op, mapping); } // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 8da30e4b9..b98675068 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -857,6 +857,7 @@ def build_triton_ir(fn, signature, specialization, constants): ret.context = context return ret, generator + def optimize_triton_ir(mod): pm = _triton.ir.pass_manager(mod.context) pm.enable_debug() @@ -868,10 +869,12 @@ def optimize_triton_ir(mod): pm.run(mod) return mod + def ast_to_ttir(fn, signature, specialization, constants): mod, _ = build_triton_ir(fn, signature, specialization, constants) return optimize_triton_ir(mod) + def ttir_to_ttgir(mod, num_warps, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) @@ -880,6 +883,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages): # can get shared memory swizzled correctly. pm.add_triton_gpu_combine_pass() pm.add_tritongpu_pipeline_pass(num_stages) + # Prefetch must be done after pipeline pass because pipeline pass + # extracts slices from the original tensor. + pm.add_tritongpu_prefetch_pass() pm.add_canonicalizer_pass() pm.add_cse_pass() pm.add_coalesce_pass() @@ -922,7 +928,6 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version) - def ptx_to_cubin(ptx: str, device: int): ''' Compile TritonGPU module to cubin. 
@@ -992,8 +997,6 @@ def path_to_ptxas(): instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()]) - - # ------------------------------------------------------------------------------ # compiler # ------------------------------------------------------------------------------ diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 62286b21d..da5b908cc 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -351,8 +351,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.mul // CHECK-NEXT: llvm.add
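
The following is a minimal standalone sketch (not part of the patch) of the index problem described in the message above: the swizzle has to be applied to a tile's offset from the start of the allocated shared memory tensor, which is what the new `offsets` field and `getCSwizzleOffset` recover, rather than to its offset within the extracted slice. The slice width, row, and swizzling parameters below are illustrative assumptions chosen to reproduce the T3 example; only the phase/XOR step mirrors the lowering code.

```cpp
// Standalone illustration only; tilesPerSlice, perPhase, maxPhase, and the
// row value are assumed for the T3 example and are not Triton's real values.
#include <cstdio>

int main() {
  const int tilesPerSlice = 2;           // each slice holds two LDGSTS tiles
  const int perPhase = 1, maxPhase = 4;  // assumed swizzling parameters

  // T3 is tile 3 of the allocated tensor, i.e. tile 1 inside slice 1.
  const int sliceIdx = 1;
  const int idxInSlice = 1;
  const int row = 2;                               // assume the access lands on phase 2
  const int phase = (row / perPhase) % maxPhase;   // = 2

  // Current behavior: swizzle the slice-relative index -> wrong tile.
  const int wrongTile = idxInSlice ^ phase;        // 1 ^ 2 = 3

  // Fixed behavior: add back the slice's offset from the start of the
  // allocation (what getCSwizzleOffset returns), then swizzle.
  const int cSwizzleOffset = sliceIdx * tilesPerSlice;            // = 2
  const int rightTile = (cSwizzleOffset + idxInSlice) ^ phase;    // 3 ^ 2 = 1

  std::printf("slice-relative swizzle -> tile %d, tensor-relative swizzle -> tile %d\n",
              wrongTile, rightTile);
  return 0;
}
```

This is also why `getBaseBeforeSwizzle` in the patch steps the base pointer back by the swizzle offset: offsets computed against the original allocation then remain valid after slicing.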
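
Similarly, here is a sketch of the K-dimension split that `Prefetcher::generatePrefetch` performs when carving the prefetched prologue slice and the remaining slice out of a dot operand, restated with plain integers instead of MLIR values; the concrete shapes and the `prefetchWidth` value are illustrative only. It also motivates the new guard in `Prefetcher::initialize`, which skips dots whose K extent is smaller than `prefetchWidth`.

```cpp
// Sketch of the K split in Prefetcher::generatePrefetch (illustrative values).
#include <array>
#include <cstdint>
#include <cstdio>

struct Slice {
  std::array<int64_t, 2> offset;
  std::array<int64_t, 2> shape;
};

Slice splitAlongK(std::array<int64_t, 2> shape, unsigned opIdx,
                  int64_t prefetchWidth, bool isPrologue) {
  // opIdx == 0 -> operand A (K is dim 1); opIdx == 1 -> operand B (K is dim 0).
  const int kIdx = (opIdx == 0) ? 1 : 0;
  Slice s;
  s.offset = {0, 0};
  s.shape = shape;
  // k => (prefetchWidth, k - prefetchWidth)
  s.offset[kIdx] = isPrologue ? 0 : prefetchWidth;
  s.shape[kIdx] = isPrologue ? prefetchWidth : shape[kIdx] - prefetchWidth;
  return s;
}

int main() {
  // A 128x64 operand-A tile with prefetchWidth = 16 (illustrative values only).
  const Slice prologue = splitAlongK({128, 64}, /*opIdx=*/0, 16, /*isPrologue=*/true);
  const Slice rest = splitAlongK({128, 64}, /*opIdx=*/0, 16, /*isPrologue=*/false);
  std::printf("prologue: offsetK=%lld shapeK=%lld\n",
              (long long)prologue.offset[1], (long long)prologue.shape[1]);
  std::printf("rest:     offsetK=%lld shapeK=%lld\n",
              (long long)rest.offset[1], (long long)rest.shape[1]);
  return 0;
}
```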