From 3236642e8fa456bfbace89e666921f586be76561 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 31 Jul 2022 20:59:31 -0700
Subject: [PATCH] [OPTIMIZER] Added memory coalescing pass (#31)

---
 .../Dialect/TritonGPU/Transforms/Passes.h     |   2 +
 .../Dialect/TritonGPU/Transforms/Passes.td    |  18 +++
 lib/Analysis/AxisInfo.cpp                     |   3 +-
 .../TritonGPU/Transforms/CMakeLists.txt       |   1 +
 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 108 ++++++++++++++++++
 python/examples/copy_strided.py               |   6 +-
 test/TritonGPU/coalesce.mlir                  |  46 ++++++++
 7 files changed, 180 insertions(+), 4 deletions(-)
 create mode 100644 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
 create mode 100644 test/TritonGPU/coalesce.mlir

diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index 3c79ab320..9ed457431 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -6,6 +6,8 @@ namespace mlir {
 
 std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
 
+std::unique_ptr<Pass> createTritonGPUCoalescePass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 2ad90bc48..d28c32a98 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -23,6 +23,24 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
   ];
 }
 
+def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
+  let summary = "coalesce";
+
+  let description = [{
+    Improve memory coalescing by rearranging the layouts of load/store operations.
+  }];
+
+  let constructor = "mlir::createTritonGPUCoalescePass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+
+  let options = [
+    Option<"numWarps", "num-warps",
+           "int32_t", /*default*/"4",
+           "number of warps">
+  ];
+}
+
 def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
   let summary = "combine triton gpu ops";
 
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index a11f1a970..5fa10769c 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -94,7 +94,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   AxisInfo curr;
   // This preserves the input axes (e.g., cast):
   if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
-                triton::gpu::ConvertLayoutOp>(op))
+                triton::PtrToIntOp, triton::IntToPtrOp,
+                triton::gpu::ConvertLayoutOp>(op))
     curr = operands[0]->getValue();
   // Constant ranges
   if (triton::MakeRangeOp make_range =
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index d110608be..9fa32b806 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
 add_public_tablegen_target(TritonGPUCombineIncGen)
 
 add_mlir_dialect_library(TritonGPUTransforms
+  Coalesce.cpp
   Combine.cpp
   Pipeline.cpp
   Verifier.cpp
diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
new file mode 100644
index 000000000..f24f6b2b3
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -0,0 +1,108 @@
+#include "triton/Analysis/AxisInfo.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::triton;
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
+
+  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
+    auto origType = ptr.getType().cast<RankedTensorType>();
+    // Get the shape of the tensor.
+    size_t rank = origType.getRank();
+    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
+    // Layout order in decreasing order of contiguity
+    SmallVector<unsigned, 4> order(rank);
+    std::iota(order.begin(), order.end(), 0);
+    auto contiguity = info.getContiguity();
+    std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
+      return contiguity[x] > contiguity[y];
+    });
+    // Thread tile size depends on memory alignment
+    SmallVector<unsigned, 4> sizePerThread(rank, 1);
+    PointerType ptrType = origType.getElementType().cast<PointerType>();
+    unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
+    unsigned maxMultiple = info.getDivisibility(order[0]);
+    unsigned maxContig = info.getContiguity(order[0]);
+    unsigned alignment = std::min(maxMultiple, maxContig);
+    unsigned perThread = std::min(alignment, 128 / numBits);
+    sizePerThread[order[0]] = perThread;
+    // create encoding
+    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
+        &getContext(), origType.getShape(), sizePerThread, order,
+        this->numWarps);
+    return encoding;
+  }
+
+  std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
+                                             Value ptr) {
+    Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
+    return [encoding](Type _type) {
+      RankedTensorType type = _type.cast<RankedTensorType>();
+      return RankedTensorType::get(type.getShape(), type.getElementType(),
+                                   encoding);
+    };
+  }
+
+  template <class T>
+  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
+                  OpBuilder builder) {
+    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
+    if (!ty)
+      return;
+    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
+    auto convertType = getTypeConverter(axisInfo, ptr);
+    // convert operands
+    SmallVector<Value, 4> newArgs;
+    for (auto v : op->getOperands())
+      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), convertType(v.getType()), v));
+    // convert output types
+    SmallVector<Type, 4> newTypes;
+    for (auto t : op->getResultTypes())
+      newTypes.push_back(convertType(t));
+    // construct new op with the new encoding
+    Operation *newOp =
+        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
+    // cast the results back to the original layout
+    for (size_t i = 0; i < op->getNumResults(); i++) {
+      auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
+      op->getResult(i).replaceAllUsesWith(newResult);
+    }
+    op->erase();
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // Run axis info analysis
+    AxisInfoAnalysis axisInfo(&getContext());
+    axisInfo.run(op);
+    OpBuilder builder(op);
+
+    // For each memory op that has a layout L1:
+    // 1. Create a coalesced memory layout L2 of the pointer operands
+    // 2. Convert all operands from layout L1 to layout L2
+    // 3. Create a new memory op that consumes these operands and
+    //    produces a tensor with layout L2
+    // 4. Convert the output of this new memory op back to L1
+    // 5. Replace all the uses of the original memory op by the new one
+    op->walk([&](Operation *curr) {
+      OpBuilder::InsertionGuard g(builder);
+      builder.setInsertionPoint(curr);
+      if (auto load = dyn_cast<triton::LoadOp>(curr))
+        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
+      if (auto store = dyn_cast<triton::StoreOp>(curr))
+        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
+    });
+  }
+};
+
+std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
+  return std::make_unique<CoalescePass>();
+}
diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py
index 7f95e8f24..922c5ba5c 100644
--- a/python/examples/copy_strided.py
+++ b/python/examples/copy_strided.py
@@ -5,8 +5,8 @@ import triton.language as tl
 
 # triton kernel
 @triton.jit
-def kernel(X, stride_xm, stride_xn,
-           Z, stride_zm, stride_zn,
+def kernel(X, stride_xm,
+           Z, stride_zn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     off_m = tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, BLOCK_N)
@@ -15,5 +15,5 @@ def kernel(X, stride_xm, stride_xn,
     Zs = Z + off_m[:, None] + off_n[None, :] * stride_zn
     tl.store(Zs, tl.load(Xs))
 
-ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
+ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
 print(ret)
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
new file mode 100644
index 000000000..1bd8e73d8
--- /dev/null
+++ b/test/TritonGPU/coalesce.mlir
@@ -0,0 +1,46 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+
+
+// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
+// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
+// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
+// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]>
+// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
+// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
+// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
+// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
+func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                %arg1: i32 {tt.divisibility = 16 : i32},
+                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                %arg3: i32 {tt.divisibility = 16 : i32}) {
+  %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
+  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
+  %1 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
+  %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
+  %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %6 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
+  %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
+  %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
+  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
+  return
+}
\ No newline at end of file
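
The layout arithmetic in getCoalescedEncoding can be sanity-checked outside of
MLIR. The following standalone Python sketch mirrors only the order/sizePerThread
heuristic (sort dimensions by decreasing contiguity, then give each thread
min(alignment, 128/bitwidth) elements along the most contiguous dimension); the
function name and the divisibility/contiguity inputs are illustrative stand-ins
for the AxisInfo lattice values, not Triton APIs:

# Standalone sketch of the sizePerThread/order heuristic in Coalesce.cpp's
# getCoalescedEncoding. All names and inputs here are illustrative, not
# Triton APIs; threadsPerWarp/warpsPerCTA are derived elsewhere in the pass.
def coalesced_layout(divisibility, contiguity, element_bits):
    rank = len(contiguity)
    # Order dimensions by decreasing contiguity: order[0] becomes the
    # fastest-varying axis of the new blocked encoding.
    order = sorted(range(rank), key=lambda d: contiguity[d], reverse=True)
    # A thread may own at most `alignment` consecutive elements, and at
    # most 128 bits worth of them per memory access.
    alignment = min(divisibility[order[0]], contiguity[order[0]])
    per_thread = min(alignment, 128 // element_bits)
    size_per_thread = [1] * rank
    size_per_thread[order[0]] = per_thread
    return order, size_per_thread

# An fp32 tensor contiguous along dim 1, as for the load in the transpose
# test: prints ([1, 0], [1, 4]), matching the [[row_layout]] CHECK line.
print(coalesced_layout(divisibility=[1, 16], contiguity=[1, 64], element_bits=32))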
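
For context on why the pass reorders layouts at all: when the 32 threads of a
warp issue a load or store, the hardware merges their addresses into as few
memory transactions as possible, so a warp reading along the contiguous
dimension touches far fewer cache lines than one striding across rows. A
back-of-the-envelope model in plain Python (128-byte lines and a 64x64 fp32
row-major tile are illustrative assumptions, not values taken from the patch):

# Count distinct 128-byte cache lines touched by one 32-thread warp.
ELEM_BYTES = 4                  # fp32
ROW_STRIDE = 64 * ELEM_BYTES    # assumed row-major 64x64 tile

def lines_touched(addr_of_lane, line_bytes=128):
    return len({addr_of_lane(lane) // line_bytes for lane in range(32)})

# Lanes walking down a column (the uncoalesced side of the transpose):
print(lines_touched(lambda lane: lane * ROW_STRIDE))  # 32 lines, one per lane
# Lanes walking along a row (what the coalesced layout produces):
print(lines_touched(lambda lane: lane * ELEM_BYTES))  # 1 line for the warp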