From 35c9ec1103610688d437ff721b5bef79d8d10986 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Sat, 26 Nov 2022 12:30:38 -0800
Subject: [PATCH] [Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)

---
 include/triton/Analysis/Utility.h             | 23 +++++++++++++
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  9 +++++-
 lib/Analysis/Allocation.cpp                   | 22 +++++--------
 lib/Analysis/Utility.cpp                      | 32 +++++++++++++++++++
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 11 +++----
 python/tests/test_core.py                     | 20 ++++++++----
 test/Conversion/triton_to_tritongpu.mlir      | 28 +++++++++++++++-
 7 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 22f6692da..77ebb0eaf 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -8,6 +8,29 @@ namespace mlir {
 
+class ReduceOpHelper {
+public:
+  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
+    srcTy = op.operand().getType().cast<RankedTensorType>();
+  }
+
+  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
+
+  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+
+  bool isFastReduction();
+
+  unsigned getInterWarpSize();
+
+  unsigned getIntraWarpSize();
+
+  unsigned getThreadsReductionAxis();
+
+private:
+  triton::ReduceOp op;
+  RankedTensorType srcTy{};
+};
+
 bool isSharedEncoding(Value value);
 
 bool maybeSharedAllocationOp(Operation *op);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index d6b7e16b6..d52c8985c 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -228,9 +228,11 @@ for
     unsigned remainingLanes = 32;
     unsigned remainingThreads = numWarps*32;
     unsigned remainingWarps = numWarps;
+    unsigned prevLanes = 1;
+    unsigned prevWarps = 1;
     SmallVector<unsigned> threadsPerWarp(rank);
     SmallVector<unsigned> warpsPerCTA(rank);
-    for (int _dim = 0; _dim < rank; ++_dim) {
+    for (int _dim = 0; _dim < rank - 1; ++_dim) {
       int i = order[_dim];
       unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
       threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
@@ -238,7 +240,12 @@ for
      remainingWarps /= warpsPerCTA[i];
      remainingLanes /= threadsPerWarp[i];
      remainingThreads /= threadsPerCTA;
+     prevLanes *= threadsPerWarp[i];
+     prevWarps *= warpsPerCTA[i];
    }
+   // Expand the last dimension to fill the remaining lanes and warps
+   threadsPerWarp[order[rank-1]] = 32 / prevLanes;
+   warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
    return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index de0a6b52c..a8576d060 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -89,24 +89,19 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 }
 
 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.axis();
-
-  bool fastReduce = axis == getOrder(srcLayout)[0];
+  ReduceOpHelper helper(op);
 
   SmallVector<unsigned> smemShape;
+  auto srcShape = helper.getSrcShape();
   for (auto d : srcShape)
     smemShape.push_back(d);
 
-  if (fastReduce) {
-    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = sizeInterWarps;
+  auto axis = op.axis();
+  if (helper.isFastReduction()) {
+    smemShape[axis] = helper.getInterWarpSize();
   } else {
-    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
-                                 gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = threadsPerCTAAxis;
+    smemShape[axis] =
+        std::min(smemShape[axis], helper.getThreadsReductionAxis());
   }
 
   return smemShape;
@@ -181,8 +176,7 @@ private:
     // TODO(Keren): Reduce with index is not supported yet.
     auto value = op->getOperand(0);
     if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-      auto srcLayout = tensorType.getEncoding();
-      bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
+      bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
       auto smemShape = getScratchConfigForReduce(reduceOp);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 780b9bf9a..ab25a41bd 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -5,6 +5,38 @@ namespace mlir {
 
+bool ReduceOpHelper::isFastReduction() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return axis == triton::gpu::getOrder(srcLayout)[0];
+}
+
+unsigned ReduceOpHelper::getInterWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  unsigned sizeIntraWarps = getIntraWarpSize();
+  return std::min(srcReduceDimSize / sizeIntraWarps,
+                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getIntraWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  return std::min(srcReduceDimSize,
+                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getThreadsReductionAxis() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
+         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index dfd31240b..908764579 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1563,9 +1563,7 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
-  if (op.axis() == srcLayout.getOrder()[0])
+  if (ReduceOpHelper(op).isFastReduction())
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
 }
@@ -1763,10 +1761,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
-  auto order = getOrder(srcLayout);
-  unsigned sizeIntraWarps = threadsPerWarp[axis];
-  unsigned sizeInterWarps = warpsPerCTA[axis];
+  ReduceOpHelper helper(op);
+  unsigned sizeIntraWarps = helper.getIntraWarpSize();
+  unsigned sizeInterWarps = helper.getInterWarpSize();
 
+  auto order = getOrder(srcLayout);
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index cbd6e046b..a32130b40 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -648,10 +648,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     else:
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
-#TODO[dongdongl]:add more cases with size of tensor less than warp size
-@pytest.mark.parametrize("axis", [0, 1])
-def test_tensor_atomic_rmw(axis, device="cuda"):
-    shape0, shape1 = 8, 8
+
+@pytest.mark.parametrize("shape, axis",
+                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
+def test_tensor_atomic_rmw(shape, axis, device="cuda"):
+    shape0, shape1 = shape
 
     # triton kernel
     @triton.jit
@@ -660,14 +661,19 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
         z = tl.sum(x, axis=AXIS)
-        tl.atomic_add(Z + off0, z)
+        if AXIS == 1:
+            tl.atomic_add(Z + off0, z)
+        else:
+            tl.atomic_add(Z + off1, z)
 
     rs = RandomState(17)
     x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    print(x)
     # reference result
-    z_ref = np.sum(x, axis=axis)
+    z_ref = np.sum(x, axis=axis, keepdims=False)
     # triton result
     x_tri = to_triton(x, device=device)
-    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    z_shape = (shape0, ) if axis == 1 else (shape1, )
+    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir
index df5b85050..a160bc881 100644
--- a/test/Conversion/triton_to_tritongpu.mlir
+++ b/test/Conversion/triton_to_tritongpu.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
 
 func @ops() {
   // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
@@ -9,6 +9,8 @@ func @ops() {
   return
 }
 
+// -----
+
 func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   // Test if LoadOp is lowered properly (see #771)
   %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
@@ -25,3 +27,27 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
   tt.store %ptrs, %c : tensor<128xf32>
   return
 }
+
+// -----
+
+func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+  // Test if the total number of threadsPerWarp is 32
+  // Test if the total number of warps is 2
+  // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
+  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
+  %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
+  %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
+  %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
+  // CHECK: tensor<4x4xf32, #blocked0> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
+  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
+  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
+  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
+  // CHECK: tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
+
+  return
+}
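
Note: as a quick sanity check of the layout heuristic patched in TritonGPUAttrDefs.td above, the following standalone Python sketch models the same distribution of lanes and warps. It is a hypothetical model, not Triton's actual API: the function name, the clamp helper, and the warpsPerCTA[i] update (which sits between the two hunks and is not shown in the diff) are paraphrased from the surrounding builder code.

    # Hypothetical model of the blocked-layout builder patched above;
    # for hand-checking numbers only, not Triton's real API.
    def blocked_layout(shape, order, size_per_thread, num_warps):
        def clamp(v, lo, hi):
            return max(lo, min(v, hi))

        rank = len(shape)
        remaining_lanes, remaining_warps = 32, num_warps
        remaining_threads = num_warps * 32
        prev_lanes = prev_warps = 1
        threads_per_warp = [1] * rank
        warps_per_cta = [1] * rank
        # Distribute lanes and warps over every dimension except the last one
        # in `order`, never exceeding what the tensor shape can absorb.
        for i in order[:-1]:
            threads_per_cta = clamp(remaining_threads, 1, shape[i] // size_per_thread[i])
            threads_per_warp[i] = clamp(threads_per_cta, 1, remaining_lanes)
            warps_per_cta[i] = clamp(threads_per_cta // threads_per_warp[i], 1, remaining_warps)
            remaining_warps //= warps_per_cta[i]
            remaining_lanes //= threads_per_warp[i]
            remaining_threads //= threads_per_cta
            prev_lanes *= threads_per_warp[i]
            prev_warps *= warps_per_cta[i]
        # Expand the last dimension to fill the remaining lanes and warps, so a
        # CTA always uses 32 threads per warp and num_warps warps, even for
        # tensors smaller than 32 * num_warps elements.
        threads_per_warp[order[-1]] = 32 // prev_lanes
        warps_per_cta[order[-1]] = num_warps // prev_warps
        return threads_per_warp, warps_per_cta

    for shape in [(4, 4), (8, 2), (16, 16)]:
        print(shape, blocked_layout(shape, order=(0, 1), size_per_thread=(1, 1), num_warps=2))

With num_warps=2 this prints threadsPerWarp/warpsPerCTA of [4, 8]/[1, 2], [8, 4]/[1, 2], and [16, 2]/[1, 2] for the 4x4, 8x2, and 16x16 tensors, matching the #blocked0, #blocked1, and #blocked2 CHECK lines in triton_to_tritongpu.mlir above.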