From 661be523c06d3ea587923361ba976a7fdf1517e3 Mon Sep 17 00:00:00 2001 From: Qingyi Liu Date: Tue, 29 Nov 2022 11:50:31 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924) --- lib/Analysis/Utility.cpp | 5 -- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 76 ++++--------------- lib/Conversion/TritonGPUToLLVM/Utility.h | 42 ++++++++++ 3 files changed, 56 insertions(+), 67 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 8458b5ee5..f23b111ec 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -58,11 +58,6 @@ SmallVector> ReduceOpHelper::getScratchConfigsFast() { unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); smemShapes[1].push_back(numWarps * 32); - /// FIXME(Qingyi): This requirement is actually not necessary, because it is - /// always smaller than smemShapes[0] shared memory block2 - smemShapes[2] = convertType(getSrcShape()); - smemShapes[2].erase(smemShapes[2].begin() + axis); - return smemShapes; } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 921895dc9..cdc230d49 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -43,6 +43,8 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::LLVM::getStructFromElements; using ::mlir::LLVM::MMA16816ConversionHelper; using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::LLVM::shflSync; +using ::mlir::LLVM::storeShared; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getElemsPerThread; @@ -307,20 +309,6 @@ T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape, reorder(shape, order)); } -Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) { - MLIRContext *ctx = rewriter.getContext(); - unsigned bits = val.getType().getIntOrFloatBitWidth(); - const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); - - PTXBuilder builder; - auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - auto *valOpr = builder.newOperand(val, c); - auto &st = builder.create<>("st")->shared().b(bits); - st(ptrOpr, valOpr).predicate(pred, "b"); - return builder.launch(rewriter, loc, void_ty(ctx)); -} - struct ConvertTritonGPUOpToLLVMPatternBase { static Value getStructFromSharedMemoryObject(Location loc, @@ -1342,9 +1330,6 @@ private: RedOp redOp, Value &acc, Value &accIndex, Value cur, Value curIndex, bool isFirst) const; - Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val, - int i) const; - // Use shared memory for reduction within warps and across warps LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const; @@ -1472,34 +1457,6 @@ void ReduceOpConversion::accumulateWithIndex( } } -Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, - Location loc, Value val, int i) const { - unsigned bits = val.getType().getIntOrFloatBitWidth(); - - if (bits == 64) { - Type vecTy = vec_ty(f32_ty, 2); - Value vec = bitcast(val, vecTy); - Value val0 = extract_element(f32_ty, vec, i32_val(0)); - Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shflSync(rewriter, loc, val0, i); - val1 = shflSync(rewriter, loc, val1, i); - vec = undef(vecTy); - vec = insert_element(vecTy, vec, val0, i32_val(0)); - vec = insert_element(vecTy, vec, val1, i32_val(1)); - return bitcast(vec, val.getType()); - } - - PTXBuilder builder; - auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32"); - auto *dOpr = builder.newOperand("=r"); - auto *aOpr = builder.newOperand(val, "r"); - auto *bOpr = builder.newConstantOperand(i); - auto *cOpr = builder.newConstantOperand("0x1f"); - auto *maskOpr = builder.newConstantOperand("0xffffffff"); - shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); - return builder.launch(rewriter, loc, val.getType(), false); -} - LogicalResult ReduceOpConversion::matchAndRewriteBasic( triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1665,7 +1622,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto smemShapes = helper.getScratchConfigsFast(); unsigned elems = product(smemShapes[0]); unsigned maxElems = std::max(elems, product(smemShapes[1])); - maxElems = std::max(maxElems, product(smemShapes[2])); Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); indexSmemBase = bitcast(indexSmemBase, indexPtrTy); @@ -1725,11 +1681,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( // reduce within warps for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(rewriter, loc, acc, N); + Value shfl = shflSync(loc, rewriter, acc, N); if (!withIndex) { accumulate(rewriter, loc, op.redOp(), acc, shfl, false); } else { - Value shflIndex = shflSync(rewriter, loc, accIndex, N); + Value shflIndex = shflSync(loc, rewriter, accIndex, N); accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl, shflIndex, false); } @@ -1750,8 +1706,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( barrier(); // the second round of shuffle reduction - // now the problem size: sizeInterWarps, s1, s2, .. , sn => - // 1, s1, s2, .. , sn + // now the problem size: sizeInterWarps, s1, s2, .. , sn // where sizeInterWarps is 2^m // // each thread needs to process: @@ -1762,6 +1717,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { Value readPtr = gep(elemPtrTy, smemBase, readOffset); + // FIXME(Qingyi): need predicate icmp_slt(threadId, i32_val(sizeInerWarps)) Value acc = load(readPtr); Value accIndex; if (withIndex) { @@ -1770,17 +1726,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( } for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(rewriter, loc, acc, N); + Value shfl = shflSync(loc, rewriter, acc, N); if (!withIndex) { accumulate(rewriter, loc, op.redOp(), acc, shfl, false); } else { - Value shflIndex = shflSync(rewriter, loc, accIndex, N); + Value shflIndex = shflSync(loc, rewriter, accIndex, N); accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl, shflIndex, false); } } - Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps)); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; Value writePtr = gep(elemPtrTy, smemBase, writeOffset); Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); @@ -1807,22 +1764,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( if (auto resultTy = op.getType().dyn_cast()) { // nd-tensor where n >= 1 auto resultLayout = resultTy.getEncoding().cast(); - SmallVector resultOrd; - for (auto ord : order) { - if (ord != 0) - resultOrd.push_back(ord - 1); - } - + auto resultShape = resultTy.getShape(); unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = - emitIndices(loc, rewriter, resultLayout, resultTy.getShape()); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); assert(resultIndices.size() == resultElems); SmallVector resultVals(resultElems); for (size_t i = 0; i < resultElems; ++i) { SmallVector readIdx = resultIndices[i]; + readIdx.insert(readIdx.begin() + axis, i32_val(0)); Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd); + linearize(rewriter, loc, readIdx, smemShapes[0], order); Value readPtr = gep(elemPtrTy, smemBase, readOffset); Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 888a572dc..323be4827 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -269,6 +269,48 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; } +Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred) { + MLIRContext *ctx = rewriter.getContext(); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); + + PTXBuilder builder; + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + auto *valOpr = builder.newOperand(val, c); + auto &st = builder.create<>("st")->shared().b(bits); + st(ptrOpr, valOpr).predicate(pred, "b"); + return builder.launch(rewriter, loc, void_ty(ctx)); +} + +Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + int i) { + unsigned bits = val.getType().getIntOrFloatBitWidth(); + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = bitcast(val, vecTy); + Value val0 = extract_element(f32_ty, vec, i32_val(0)); + Value val1 = extract_element(f32_ty, vec, i32_val(1)); + val0 = shflSync(loc, rewriter, val0, i); + val1 = shflSync(loc, rewriter, val1, i); + vec = undef(vecTy); + vec = insert_element(vecTy, vec, val0, i32_val(0)); + vec = insert_element(vecTy, vec, val1, i32_val(1)); + return bitcast(vec, val.getType()); + } + + PTXBuilder builder; + auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32"); + auto *dOpr = builder.newOperand("=r"); + auto *aOpr = builder.newOperand(val, "r"); + auto *bOpr = builder.newConstantOperand(i); + auto *cOpr = builder.newConstantOperand("0x1f"); + auto *maskOpr = builder.newConstantOperand("0xffffffff"); + shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); + return builder.launch(rewriter, loc, val.getType(), false); +} + } // namespace LLVM } // namespace mlir