From c280ebda1b75a3ae579ec05a753e85a6724dda49 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 1 Dec 2022 11:54:18 -0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933) 1. Add missing barriers and revert the previous temporary solution 2. Extract the `run` method from membar analysis because the membar analysis should have two phases, including construction, which doesn't modify any IR, and modification, which adds barrier IRs. Hope this could make the use of membar clear. --- include/triton/Analysis/Membar.h | 10 +-- lib/Analysis/Membar.cpp | 49 ++++++++--- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 58 ++++++------- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 86 ++++++++++--------- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 4 +- python/src/triton.cc | 13 ++- test/Analysis/test-membar.mlir | 48 ++++++++++- test/lib/Analysis/TestMembar.cpp | 4 +- 8 files changed, 170 insertions(+), 102 deletions(-) diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index 242b54ecc..ceb192753 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -29,7 +29,11 @@ public: /// The following circumstances are not considered yet: /// - Double buffers /// - N buffers - MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); } + MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(); private: struct RegionInfo { @@ -82,10 +86,6 @@ private: } }; - /// Runs the membar analysis to the given operation, inserts a barrier if - /// necessary. - void run(); - /// Applies the barrier analysis based on the SCF dialect, in which each /// region has a single basic block only. /// Example: diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 715265c0a..68aebdbd1 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -24,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation, // scf.if only: two regions // scf.for: one region RegionInfo curRegionInfo; - for (auto ®ion : operation->getRegions()) { - // Copy the parent info as the current info. - RegionInfo regionInfo = *parentRegionInfo; - for (auto &block : region.getBlocks()) { - assert(region.getBlocks().size() == 1 && - "Multiple blocks in a region is not supported"); - for (auto &op : block.getOperations()) { - // Traverse the nested operation. - dfsOperation(&op, ®ionInfo, builder); + auto traverseRegions = [&]() -> auto{ + for (auto ®ion : operation->getRegions()) { + // Copy the parent info as the current info. + RegionInfo regionInfo = *parentRegionInfo; + for (auto &block : region.getBlocks()) { + assert(region.getBlocks().size() == 1 && + "Multiple blocks in a region is not supported"); + for (auto &op : block.getOperations()) { + // Traverse the nested operation. + dfsOperation(&op, ®ionInfo, builder); + } } + curRegionInfo.join(regionInfo); } - curRegionInfo.join(regionInfo); + // Set the parent region info as the union of the nested region info. + *parentRegionInfo = curRegionInfo; + }; + + traverseRegions(); + if (isa(operation)) { + // scf.for can have two possible inputs: the init value and the + // previous iteration's result. Although we've applied alias analysis, + // there could be unsynced memory accesses on reused memories. + // For example, consider the following code: + // %1 = convert_layout %0: blocked -> shared + // ... + // gpu.barrier + // ... + // %5 = convert_layout %4 : shared -> dot + // %6 = tt.dot %2, %5 + // scf.yield + // + // Though %5 could be released before scf.yield, it may shared the same + // memory with %1. So we actually have to insert a barrier before %1 to + // make sure the memory is synced. + traverseRegions(); } - // Set the parent region info as the union of the nested region info. - *parentRegionInfo = curRegionInfo; } } @@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, // Do not insert barriers before control flow operations and // alloc/extract/insert // alloc is an allocation op without memory write. - // In contrast, arith.constant is an allocation op with memory write. - // FIXME(Keren): extract is always alias for now + // FIXME(Keren): extract_slice is always alias for now return; } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 349ccef01..68fc4bac3 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -644,7 +644,6 @@ public: return multiDimIdx; } - struct SmallVectorKeyInfo { static unsigned getHashValue(const SmallVector &key) { return llvm::hash_combine_range(key.begin(), key.end()); @@ -672,7 +671,7 @@ public: emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape)); unsigned numIndices = parentIndices.size(); SmallVector> resultIndices; - for (unsigned i = 0; i < numIndices; ++i){ + for (unsigned i = 0; i < numIndices; ++i) { SmallVector indices = parentIndices[i]; indices.erase(indices.begin() + dim); resultIndices.push_back(indices); @@ -1203,14 +1202,14 @@ struct BroadcastOpConversion auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape); SmallVector srcVals = getElementsFromStruct(loc, src, rewriter); DenseMap, Value, SmallVectorKeyInfo> srcValues; - for(size_t i = 0; i < srcOffsets.size(); i++){ + for (size_t i = 0; i < srcOffsets.size(); i++) { srcValues[srcOffsets[i]] = srcVals[i]; } SmallVector resultVals; - for(size_t i = 0; i < resultOffsets.size(); i++) { + for (size_t i = 0; i < resultOffsets.size(); i++) { auto offset = resultOffsets[i]; - for(size_t j = 0; j < srcShape.size(); j++) - if(srcShape[j]==1) + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) offset[j] = 0; resultVals.push_back(srcValues.lookup(offset)); } @@ -1940,8 +1939,8 @@ struct MakeRangeOpConversion unsigned elems = idxs.size(); SmallVector retVals(elems); // TODO: slice layout has more elements than expected. - // Unexpected behavior for make range, but genereally ok when followed by expand dims + broadcast. - // very weird behavior otherwise potentially. + // Unexpected behavior for make range, but genereally ok when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. for (const auto multiDim : llvm::enumerate(idxs)) { assert(multiDim.value().size() == 1); retVals[multiDim.index()] = add(multiDim.value()[0], start); @@ -2647,13 +2646,13 @@ public: } // dot_op = #mma // when #mma = MmaEncoding - if(srcLayout.isa() && + if (srcLayout.isa() && dstLayout.isa()) { auto srcMmaLayout = srcLayout.cast(); auto dstDotLayout = dstLayout.cast(); - if(srcMmaLayout.getWarpsPerCTA()[1] == 1 && - dstDotLayout.getOpIdx() == 0 && - dstDotLayout.getParent() == srcMmaLayout) { + if (srcMmaLayout.getWarpsPerCTA()[1] == 1 && + dstDotLayout.getOpIdx() == 0 && + dstDotLayout.getParent() == srcMmaLayout) { // get source values Location loc = op->getLoc(); auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); @@ -2662,35 +2661,37 @@ public: this->getTypeConverter()->convertType(srcTy.getElementType()); // for the destination type, we need to pack values together // so they can be consumed by tensor core operations - unsigned vecSize = std::max(32 / elemTy.getIntOrFloatBitWidth(), 1); + unsigned vecSize = + std::max(32 / elemTy.getIntOrFloatBitWidth(), 1); Type vecTy = vec_ty(elemTy, vecSize); - SmallVector types(elems/vecSize, vecTy); + SmallVector types(elems / vecSize, vecTy); SmallVector vecVals; - for(unsigned i = 0; i < elems; i += vecSize) { + for (unsigned i = 0; i < elems; i += vecSize) { Value packed = rewriter.create(loc, vecTy); - for(unsigned j = 0; j < vecSize; j++) - packed = insert_element(vecTy, packed, vals[i+j], i32_val(j)); + for (unsigned j = 0; j < vecSize; j++) + packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); vecVals.push_back(packed); } - + // This needs to be ordered the same way that // ldmatrix.x4 would order it // TODO: this needs to be refactor so we don't // implicitly depends on how emitOffsetsForMMAV2 // is implemented SmallVector reorderedVals; - for(unsigned i = 0; i < vecVals.size(); i += 4) { + for (unsigned i = 0; i < vecVals.size(); i += 4) { reorderedVals.push_back(vecVals[i]); - reorderedVals.push_back(vecVals[i+2]); - reorderedVals.push_back(vecVals[i+1]); - reorderedVals.push_back(vecVals[i+3]); + reorderedVals.push_back(vecVals[i + 2]); + reorderedVals.push_back(vecVals[i + 1]); + reorderedVals.push_back(vecVals[i + 3]); } // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); - - Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); - Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy); + Type structTy = + LLVM::LLVMStructType::getLiteral(this->getContext(), types); + Value view = + getStructFromElements(loc, reorderedVals, rewriter, structTy); rewriter.replaceOp(op, view); return success(); } @@ -3138,10 +3139,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - // TODO[Keren]: A temporary workaround for an issue from membar pass. - // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS - barrier(); - Value src = op.src(); Value dst = op.result(); auto srcTy = src.getType().cast(); @@ -4698,7 +4695,8 @@ public: decomposeInsertSliceAsyncOp(mod); Allocation allocation(mod); - MembarAnalysis membar(&allocation); + MembarAnalysis membarPass(&allocation); + membarPass.run(); RewritePatternSet scf_patterns(context); mlir::populateLoopToStdConversionPatterns(scf_patterns); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 884dd68f0..6ec2ee127 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -50,22 +50,25 @@ public: auto dstType = convert.getType().cast(); if (srcType.getEncoding().isa() && dstType.getEncoding().isa()) { - auto dstDotOperand = dstType.getEncoding().cast(); + auto dstDotOperand = + dstType.getEncoding().cast(); auto dstParent = dstDotOperand.getParent(); - if(dstDotOperand.getOpIdx()==1 || - !dstParent.isa()) + if (dstDotOperand.getOpIdx() == 1 || + !dstParent.isa()) return mlir::failure(); auto dstParentMma = dstParent.cast(); - if(dstParentMma.getVersion() == 1 || - dstParentMma.getWarpsPerCTA()[1] > 1) + if (dstParentMma.getVersion() == 1 || + dstParentMma.getWarpsPerCTA()[1] > 1) return mlir::failure(); - SetVector bwdSlices; + SetVector bwdSlices; mlir::getBackwardSlice(convert.getResult(), &bwdSlices); - if(llvm::find_if(bwdSlices, [](Operation *op) { return isa(op); }) == bwdSlices.end()) + if (llvm::find_if(bwdSlices, [](Operation *op) { + return isa(op); + }) == bwdSlices.end()) return mlir::failure(); - - auto tmpType = - RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma); + + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), dstParentMma); auto tmp = rewriter.create( convert.getLoc(), tmpType, convert.getOperand()); auto newConvert = rewriter.create( @@ -601,10 +604,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef &shape, } } - SmallVector warpsPerTileV1(triton::DotOp dotOp, - const ArrayRef shape, - int numWarps) { + const ArrayRef shape, + int numWarps) { SmallVector ret = {1, 1}; SmallVector shapePerWarp = mmaVersionToShapePerWarp(1, shape, numWarps); @@ -624,35 +626,37 @@ SmallVector warpsPerTileV1(triton::DotOp dotOp, } SmallVector warpsPerTileV2(triton::DotOp dotOp, - const ArrayRef shape, - int numWarps) { - SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - if(llvm::find_if(slices, [](Operation *op) { return isa(op); }) != slices.end()) - return {(unsigned)numWarps, 1}; - - SmallVector ret = {1, 1}; - SmallVector shapePerWarp = {16, 8}; - bool changed = false; - // TODO (@daadaada): double-check. - // original logic in - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 - // seems buggy for shape = [32, 16] ? - do { - changed = false; - if (ret[0] * ret[1] >= numWarps) - break; - if (shape[0] / shapePerWarp[0] / ret[0] >= - shape[1] / (shapePerWarp[1] * 2) / ret[1]) { - if (ret[0] < shape[0] / shapePerWarp[0]) { - ret[0] *= 2; - } else - ret[1] *= 2; - } else { + const ArrayRef shape, + int numWarps) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { + return isa(op); + }) != slices.end()) + return {(unsigned)numWarps, 1}; + + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = {16, 8}; + bool changed = false; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + changed = false; + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else ret[1] *= 2; - } - } while (true); - return ret; + } else { + ret[1] *= 2; + } + } while (true); + return ret; } } // namespace diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 4ebff3331..8287350de 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -130,10 +130,10 @@ LogicalResult Prefetcher::initialize() { if (dotsInFor.empty()) return failure(); - + // TODO: segfault (original for still has uses) // when used in flash attention that has 2 dots in the loop - if(dotsInFor.size() > 1) + if (dotsInFor.size() > 1) return failure(); // returns source of cvt diff --git a/python/src/triton.cc b/python/src/triton.cc index 95fa120bc..f7ef61547 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1251,13 +1251,12 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)), values); }) - // Undef - .def("create_undef", - [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value { - auto loc = self.getUnknownLoc(); - return self.create<::mlir::LLVM::UndefOp>(loc, type); - }) - ; + // Undef + .def("create_undef", + [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create<::mlir::LLVM::UndefOp>(loc, type); + }); py::class_(m, "pass_manager") .def(py::init()) diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index c2afeb386..42ff6c9c3 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { - %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL> // CHECK-NEXT: Membar 6 - %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL> scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } // CHECK-NEXT: Membar 9 @@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : return } +// Although cst2 is not an argument of scf.yield, its memory is reused by cst1. +// So we need a barrier both before and after cst1 +// CHECK-LABEL: for_reuse +func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 2 + %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + // CHECK-NEXT: Membar 5 + %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 7 + %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> + } + // CHECK-NEXT: Membar 10 + %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED> + return +} + + +// CHECK-LABEL: for_reuse_nested +func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: Membar 2 + %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + // CHECK-NEXT: Membar 5 + %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + // CHECK-NEXT: Membar 7 + %cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> + } + scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> + } + // CHECK-NEXT: Membar 11 + %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED> + return +} + } diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index 6a4c7ed0a..03a56cc10 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -26,7 +26,9 @@ struct TestMembarPass auto op_name = SymbolTable::getSymbolName(operation).getValue().str(); os << op_name << "\n"; Allocation allocation(operation); - MembarAnalysis analysis(&allocation); + MembarAnalysis membarPass(&allocation); + membarPass.run(); + size_t operationId = 0; operation->walk([&](Operation *op) { if (isa(op)) {