From 12d60cb4a306e8397ee00717486eb0f36c6eddcb Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 1 Nov 2022 13:19:58 -0700 Subject: [PATCH] [BACKEND] Added support for 1D conversion blocked -> slice (#831) --- lib/Analysis/Allocation.cpp | 24 ++--- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 26 ++--- python/tests/test_core.py | 102 +++++++++--------- test/Conversion/tritongpu_to_llvm.mlir | 26 +++++ 5 files changed, 103 insertions(+), 78 deletions(-) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index cafe3b777..c13310160 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -11,7 +11,9 @@ #include using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, "Unexpect layout in getScratchConfigForCvtLayout()"); unsigned rank = dstTy.getRank(); SmallVector paddedRepShape(rank); - if (auto srcSliceLayout = srcLayout.dyn_cast()) - srcLayout = srcSliceLayout.getParent(); - if (auto dstSliceLayout = dstLayout.dyn_cast()) - dstLayout = dstSliceLayout.getParent(); auto srcBlockedLayout = srcLayout.dyn_cast(); auto srcMmaLayout = srcLayout.dyn_cast(); auto dstBlockedLayout = dstLayout.dyn_cast(); auto dstMmaLayout = dstLayout.dyn_cast(); - assert((srcBlockedLayout || srcMmaLayout) && - "Unexpected srcLayout in getScratchConfigForCvtLayout"); - assert((dstBlockedLayout || dstMmaLayout) && - "Unexpected dstLayout in getScratchConfigForCvtLayout"); assert(!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"); - auto inOrd = - srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder(); - auto outOrd = - dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder(); - unsigned srcContigPerThread = - srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2; - unsigned dstContigPerThread = - dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2; + auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout); + auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout); + unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]]; + unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]]; // TODO: Fix the legacy issue that ourOrd[0] == 0 always means // that we cannot do vectorization. inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; @@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, std::max(std::min(srcTy.getShape()[d], srcShapePerCTA[d]), std::min(dstTy.getShape()[d], dstShapePerCTA[d])); } + if (rank == 1) + return paddedRepShape; unsigned paddedDim = 1; if (auto dstBlockedLayout = dstLayout.dyn_cast()) { paddedDim = dstBlockedLayout.getOrder()[0]; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 85f20a5eb..a8da162a3 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1197,7 +1197,7 @@ struct BroadcastOpConversion broadcastDims.push_back(d); srcLogicalShape[d] = 1; srcLogicalShape[d + rank] = - std::max(unsigned(1), srcLayout.getSizePerThread()[d]); + std::max(1, srcLayout.getSizePerThread()[d]); } else { srcLogicalShape[d] = numCtas; srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d]; @@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( } barrier(); if (dstLayout.isa() || + dstLayout.isa() || dstLayout.isa()) { processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, outOrd, outVals, diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index c44cbf5fc..96b0925bc 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -68,18 +68,7 @@ SmallVector getSizePerThread(Attribute layout) { return SmallVector(blockedLayout.getSizePerThread().begin(), blockedLayout.getSizePerThread().end()); } else if (auto sliceLayout = layout.dyn_cast()) { - unsigned dim = sliceLayout.getDim(); - auto parent = sliceLayout.getParent(); - if (auto blockedParent = parent.dyn_cast()) { - SmallVector sizePerThread( - blockedParent.getSizePerThread().begin(), - blockedParent.getSizePerThread().end()); - sizePerThread.erase(sizePerThread.begin() + dim); - return sizePerThread; - } else { - assert(0 && "SliceEncodingAttr with parent other than " - "BlockedEncodingAttr not implemented"); - } + return getSizePerThread(sliceLayout.getParent()); } else if (auto mmaLayout = layout.dyn_cast()) { assert(mmaLayout.getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); @@ -144,6 +133,19 @@ SmallVector getOrder(const Attribute &layout) { blockedLayout.getOrder().end()); } else if (auto mmaLayout = layout.dyn_cast()) { return SmallVector{1, 0}; + } else if (auto sliceLayout = layout.dyn_cast()) { + SmallVector parentOrder = getOrder(sliceLayout.getParent()); + unsigned dim = sliceLayout.getDim(); + SmallVector order; + for (unsigned d : parentOrder) { + if (d == dim) + continue; + else if (d > dim) + order.push_back(d - 1); + else + order.push_back(d); + } + return order; } else if (auto sharedLayout = layout.dyn_cast()) { return SmallVector(sharedLayout.getOrder().begin(), sharedLayout.getOrder().end()); diff --git a/python/tests/test_core.py b/python/tests/test_core.py index afcc67d9d..5fd25b118 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -493,61 +493,65 @@ def make_ptr_str(name, shape): # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` -# @pytest.mark.parametrize("expr, dtype_str", [ -# (f'x[{s}]', d) -# for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] -# for d in ['int32', 'uint32', 'uint16'] -# ]) -# def test_index1d(expr, dtype_str, device='cuda'): -# rank_x = expr.count(':') -# rank_y = expr.count(',') + 1 -# shape_x = [32 for _ in range(rank_x)] -# shape_z = [32 for _ in range(rank_y)] -# shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] -# shape_z_dim_mismatch = [64 for _ in range(rank_y)] +@pytest.mark.parametrize("expr, dtype_str", [ + (f'x[{s}]', d) + for s in ['None, :', ':, None', + # TODO: 3D + # 'None, :, :', + # ':, :, None' + ] + for d in ['int32', 'uint32', 'uint16'] +]) +def test_index1d(expr, dtype_str, device='cuda'): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] -# # Triton kernel -# @triton.jit -# def kernel(Z, X, SIZE: tl.constexpr): -# m = tl.arange(0, SIZE) -# n = tl.arange(0, SIZE) -# x = tl.load(X_PTR_EXPR) -# z = GENERATE_TEST_HERE -# tl.store(Z_PTR_EXPR, z) + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) -# def generate_kernel(shape_x, shape_z): -# to_replace = { -# 'X_PTR_EXPR': make_ptr_str('X', shape_x), -# 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), -# 'GENERATE_TEST_HERE': expr, -# } -# return patch_kernel(kernel, to_replace) + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) -# kernel_match = generate_kernel(shape_x, shape_z) -# kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) -# kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) -# # torch result -# x = numpy_random(shape_x, dtype_str=dtype_str) -# y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) -# z_ref = eval(expr) + y -# # triton result -# z_tri = to_triton(np.empty_like(z_ref), device=device) -# x_tri = to_triton(x) -# kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) -# # compare -# assert (z_ref == to_numpy(z_tri)).all() + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() -# def catch_compilation_error(kernel): -# try: -# kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) -# except triton.CompilationError as e: -# np.testing.assert_(True) -# except BaseException: -# np.testing.assert_(False) + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) -# catch_compilation_error(kernel_dim_mismatch) -# catch_compilation_error(kernel_rank_mismatch) + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) # # --------------- diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 3bfed0569..807bd1396 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -715,6 +715,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { } } +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice0 + func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { + // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> + %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + return + } +} + + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice1 + func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { + // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr, 3> + %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + return + } +} + + // ----- #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>