diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h
index 242b54ecc..ceb192753 100644
--- a/include/triton/Analysis/Membar.h
+++ b/include/triton/Analysis/Membar.h
@@ -29,7 +29,11 @@ public:
   /// The following circumstances are not considered yet:
   /// - Double buffers
   /// - N buffers
-  MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
+  MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
+
+  /// Runs the membar analysis on the given operation and inserts a barrier if
+  /// necessary.
+  void run();
 
 private:
   struct RegionInfo {
@@ -82,10 +86,6 @@ private:
     }
   };
 
-  /// Runs the membar analysis to the given operation, inserts a barrier if
-  /// necessary.
-  void run();
-
   /// Applies the barrier analysis based on the SCF dialect, in which each
   /// region has a single basic block only.
   /// Example:
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index 715265c0a..68aebdbd1 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -24,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation,
     // scf.if only: two regions
     // scf.for: one region
     RegionInfo curRegionInfo;
-    for (auto &region : operation->getRegions()) {
-      // Copy the parent info as the current info.
-      RegionInfo regionInfo = *parentRegionInfo;
-      for (auto &block : region.getBlocks()) {
-        assert(region.getBlocks().size() == 1 &&
-               "Multiple blocks in a region is not supported");
-        for (auto &op : block.getOperations()) {
-          // Traverse the nested operation.
-          dfsOperation(&op, &regionInfo, builder);
+    auto traverseRegions = [&]() -> auto {
+      for (auto &region : operation->getRegions()) {
+        // Copy the parent info as the current info.
+        RegionInfo regionInfo = *parentRegionInfo;
+        for (auto &block : region.getBlocks()) {
+          assert(region.getBlocks().size() == 1 &&
+                 "Multiple blocks in a region is not supported");
+          for (auto &op : block.getOperations()) {
+            // Traverse the nested operation.
+            dfsOperation(&op, &regionInfo, builder);
+          }
         }
+        curRegionInfo.join(regionInfo);
       }
-      curRegionInfo.join(regionInfo);
+      // Set the parent region info as the union of the nested region info.
+      *parentRegionInfo = curRegionInfo;
+    };
+
+    traverseRegions();
+    if (isa<scf::ForOp>(operation)) {
+      // scf.for can have two possible inputs: the init value and the
+      // previous iteration's result. Although we've applied alias analysis,
+      // there could be unsynced memory accesses on reused memories.
+      // For example, consider the following code:
+      // %1 = convert_layout %0 : blocked -> shared
+      // ...
+      // gpu.barrier
+      // ...
+      // %5 = convert_layout %4 : shared -> dot
+      // %6 = tt.dot %2, %5
+      // scf.yield
+      //
+      // Though %5 could be released before scf.yield, it may share the same
+      // memory with %1. So we actually have to insert a barrier before %1 to
+      // make sure the memory is synced.
+      traverseRegions();
     }
-    // Set the parent region info as the union of the nested region info.
-    *parentRegionInfo = curRegionInfo;
   }
 }
 
@@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
     // Do not insert barriers before control flow operations and
     // alloc/extract/insert
     // alloc is an allocation op without memory write.
-    // In contrast, arith.constant is an allocation op with memory write.
-    // FIXME(Keren): extract is always alias for now
+    // FIXME(Keren): extract_slice is always alias for now
     return;
   }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 349ccef01..68fc4bac3 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -644,7 +644,6 @@ public:
     return multiDimIdx;
   }
-
   struct SmallVectorKeyInfo {
     static unsigned getHashValue(const SmallVector<unsigned> &key) {
       return llvm::hash_combine_range(key.begin(), key.end());
@@ -672,7 +671,7 @@ public:
         emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
     unsigned numIndices = parentIndices.size();
     SmallVector<SmallVector<Value>> resultIndices;
-    for (unsigned i = 0; i < numIndices; ++i){
+    for (unsigned i = 0; i < numIndices; ++i) {
       SmallVector<Value> indices = parentIndices[i];
       indices.erase(indices.begin() + dim);
       resultIndices.push_back(indices);
@@ -1203,14 +1202,14 @@ struct BroadcastOpConversion
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
     SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
     DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
-    for(size_t i = 0; i < srcOffsets.size(); i++){
+    for (size_t i = 0; i < srcOffsets.size(); i++) {
       srcValues[srcOffsets[i]] = srcVals[i];
     }
     SmallVector<Value> resultVals;
-    for(size_t i = 0; i < resultOffsets.size(); i++) {
+    for (size_t i = 0; i < resultOffsets.size(); i++) {
       auto offset = resultOffsets[i];
-      for(size_t j = 0; j < srcShape.size(); j++)
-        if(srcShape[j]==1)
+      for (size_t j = 0; j < srcShape.size(); j++)
+        if (srcShape[j] == 1)
           offset[j] = 0;
       resultVals.push_back(srcValues.lookup(offset));
     }
@@ -1940,8 +1939,8 @@ struct MakeRangeOpConversion
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
     // TODO: slice layout has more elements than expected.
-    // Unexpected behavior for make range, but genereally ok when followed by expand dims + broadcast.
-    // very weird behavior otherwise potentially.
+    // Unexpected behavior for make range, but generally ok when followed by
+    // expand dims + broadcast. Very weird behavior otherwise potentially.
     for (const auto multiDim : llvm::enumerate(idxs)) {
       assert(multiDim.value().size() == 1);
       retVals[multiDim.index()] = add(multiDim.value()[0], start);
@@ -2647,13 +2646,13 @@ public:
     }
     // dot_op<opIdx=0, parent=#mma> = #mma
    // when #mma = MmaEncoding<warpsPerCTA=[..., 1]>
-    if(srcLayout.isa<MmaEncodingAttr>() &&
+    if (srcLayout.isa<MmaEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
      auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
      auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
-      if(srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dstDotLayout.getOpIdx() == 0 &&
-         dstDotLayout.getParent() == srcMmaLayout) {
+      if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
+          dstDotLayout.getOpIdx() == 0 &&
+          dstDotLayout.getParent() == srcMmaLayout) {
        // get source values
        Location loc = op->getLoc();
        auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
@@ -2662,35 +2661,37 @@ public:
            this->getTypeConverter()->convertType(srcTy.getElementType());
        // for the destination type, we need to pack values together
        // so they can be consumed by tensor core operations
-        unsigned vecSize = std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
+        unsigned vecSize =
+            std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
        Type vecTy = vec_ty(elemTy, vecSize);
-        SmallVector<Type> types(elems/vecSize, vecTy);
+        SmallVector<Type> types(elems / vecSize, vecTy);
        SmallVector<Value> vecVals;
-        for(unsigned i = 0; i < elems; i += vecSize) {
+        for (unsigned i = 0; i < elems; i += vecSize) {
          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for(unsigned j = 0; j < vecSize; j++)
-            packed = insert_element(vecTy, packed, vals[i+j], i32_val(j));
+          for (unsigned j = 0; j < vecSize; j++)
+            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
          vecVals.push_back(packed);
        }
-
+
        // This needs to be ordered the same way that
        // ldmatrix.x4 would order it
        // TODO: this needs to be refactor so we don't
        // implicitly depends on how emitOffsetsForMMAV2
        // is implemented
        SmallVector<Value> reorderedVals;
-        for(unsigned i = 0; i < vecVals.size(); i += 4) {
+        for (unsigned i = 0; i < vecVals.size(); i += 4) {
          reorderedVals.push_back(vecVals[i]);
-          reorderedVals.push_back(vecVals[i+2]);
-          reorderedVals.push_back(vecVals[i+1]);
-          reorderedVals.push_back(vecVals[i+3]);
+          reorderedVals.push_back(vecVals[i + 2]);
+          reorderedVals.push_back(vecVals[i + 1]);
+          reorderedVals.push_back(vecVals[i + 3]);
        }
        // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
-
-        Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-        Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
+        Type structTy =
+            LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+        Value view =
+            getStructFromElements(loc, reorderedVals, rewriter, structTy);
        rewriter.replaceOp(op, view);
        return success();
      }
@@ -3138,10 +3139,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  // TODO[Keren]: A temporary workaround for an issue from membar pass.
-  // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
-  barrier();
-
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
@@ -4698,7 +4695,8 @@ public:
     decomposeInsertSliceAsyncOp(mod);
 
     Allocation allocation(mod);
-    MembarAnalysis membar(&allocation);
+    MembarAnalysis membarPass(&allocation);
+    membarPass.run();
 
     RewritePatternSet scf_patterns(context);
     mlir::populateLoopToStdConversionPatterns(scf_patterns);
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 884dd68f0..6ec2ee127 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -50,22 +50,25 @@ public:
     auto dstType = convert.getType().cast<RankedTensorType>();
     if (srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>() &&
         dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
-      auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
+      auto dstDotOperand =
+          dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
       auto dstParent = dstDotOperand.getParent();
-      if(dstDotOperand.getOpIdx()==1 ||
-         !dstParent.isa<triton::gpu::MmaEncodingAttr>())
+      if (dstDotOperand.getOpIdx() == 1 ||
+          !dstParent.isa<triton::gpu::MmaEncodingAttr>())
         return mlir::failure();
       auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
-      if(dstParentMma.getVersion() == 1 ||
-         dstParentMma.getWarpsPerCTA()[1] > 1)
+      if (dstParentMma.getVersion() == 1 ||
+          dstParentMma.getWarpsPerCTA()[1] > 1)
         return mlir::failure();
-      SetVector<Operation*> bwdSlices;
+      SetVector<Operation *> bwdSlices;
       mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
-      if(llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
+      if (llvm::find_if(bwdSlices, [](Operation *op) {
+            return isa<triton::DotOp>(op);
+          }) == bwdSlices.end())
         return mlir::failure();
-
-      auto tmpType =
-          RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
+
+      auto tmpType = RankedTensorType::get(
+          dstType.getShape(), dstType.getElementType(), dstParentMma);
       auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
           convert.getLoc(), tmpType, convert.getOperand());
       auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -601,10 +604,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }
 
-
 SmallVector<unsigned> warpsPerTileV1(triton::DotOp dotOp,
-                                        const ArrayRef<int64_t> shape,
-                                        int numWarps) {
+                                     const ArrayRef<int64_t> shape,
+                                     int numWarps) {
   SmallVector<unsigned> ret = {1, 1};
   SmallVector<unsigned> shapePerWarp =
       mmaVersionToShapePerWarp(1, shape, numWarps);
@@ -624,35 +626,37 @@ SmallVector<unsigned> warpsPerTileV2(triton::DotOp dotOp,
-                                        const ArrayRef<int64_t> shape,
-                                        int numWarps) {
-    SetVector<Operation*> slices;
-    mlir::getForwardSlice(dotOp.getResult(), &slices);
-    if(llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
-      return {(unsigned)numWarps, 1};
-
-    SmallVector<unsigned> ret = {1, 1};
-    SmallVector<unsigned> shapePerWarp = {16, 8};
-    bool changed = false;
-    // TODO (@daadaada): double-check.
-    // original logic in
-    // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-    // seems buggy for shape = [32, 16] ?
-    do {
-      changed = false;
-      if (ret[0] * ret[1] >= numWarps)
-        break;
-      if (shape[0] / shapePerWarp[0] / ret[0] >=
-          shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-        if (ret[0] < shape[0] / shapePerWarp[0]) {
-          ret[0] *= 2;
-        } else
-          ret[1] *= 2;
-      } else {
+                                     const ArrayRef<int64_t> shape,
+                                     int numWarps) {
+  SetVector<Operation *> slices;
+  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  if (llvm::find_if(slices, [](Operation *op) {
+        return isa<triton::DotOp>(op);
+      }) != slices.end())
+    return {(unsigned)numWarps, 1};
+
+  SmallVector<unsigned> ret = {1, 1};
+  SmallVector<unsigned> shapePerWarp = {16, 8};
+  bool changed = false;
+  // TODO (@daadaada): double-check.
+  // original logic in
+  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
+  // seems buggy for shape = [32, 16] ?
+  do {
+    changed = false;
+    if (ret[0] * ret[1] >= numWarps)
+      break;
+    if (shape[0] / shapePerWarp[0] / ret[0] >=
+        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
+      if (ret[0] < shape[0] / shapePerWarp[0]) {
+        ret[0] *= 2;
+      } else
         ret[1] *= 2;
-      }
-    } while (true);
-    return ret;
+    } else {
+      ret[1] *= 2;
+    }
+  } while (true);
+  return ret;
 }
 
 } // namespace
diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
index 4ebff3331..8287350de 100644
--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -130,10 +130,10 @@ LogicalResult Prefetcher::initialize() {
 
   if (dotsInFor.empty())
     return failure();
-
+
   // TODO: segfault (original for still has uses)
   // when used in flash attention that has 2 dots in the loop
-  if(dotsInFor.size() > 1)
+  if (dotsInFor.size() > 1)
     return failure();
 
   // returns source of cvt
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 95fa120bc..f7ef61547 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1251,13 +1251,12 @@ void init_triton_ir(py::module &&m) {
                                              llvm::StringRef(prefix)),
                values);
           })
-      // Undef
-      .def("create_undef",
-           [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
-             auto loc = self.getUnknownLoc();
-             return self.create<::mlir::LLVM::UndefOp>(loc, type);
-           })
-      ;
+      // Undef
+      .def("create_undef",
+           [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<::mlir::LLVM::UndefOp>(loc, type);
+           });
 
   py::class_<mlir::PassManager>(m, "pass_manager")
       .def(py::init<mlir::MLIRContext *>())
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index c2afeb386..42ff6c9c3 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
     // CHECK-NEXT: Membar 6
-    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
     scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
   // CHECK-NEXT: Membar 9
@@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   return
 }
 
+// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
+// So we need a barrier both before and after cst1.
+// CHECK-LABEL: for_reuse
+func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: Membar 2
+  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    // CHECK-NEXT: Membar 5
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK-NEXT: Membar 7
+    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  // CHECK-NEXT: Membar 10
+  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  return
+}
+
+
+// CHECK-LABEL: for_reuse_nested
+func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: Membar 2
+  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    // CHECK-NEXT: Membar 5
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+      // CHECK-NEXT: Membar 7
+      %cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+      scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+    }
+    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  // CHECK-NEXT: Membar 11
+  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  return
+}
+
 }
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index 6a4c7ed0a..03a56cc10 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -26,7 +26,9 @@ struct TestMembarPass
     auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
     os << op_name << "\n";
     Allocation allocation(operation);
-    MembarAnalysis analysis(&allocation);
+    MembarAnalysis membarPass(&allocation);
+    membarPass.run();
+
     size_t operationId = 0;
     operation->walk([&](Operation *op) {
       if (isa<gpu::BarrierOp>(op)) {
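
Usage note (outside the patch): with this change, constructing a `MembarAnalysis` no longer runs the analysis; callers have to invoke `run()` explicitly, as the updated `TritonGPUToLLVM.cpp` and `TestMembar.cpp` call sites do. Below is a minimal sketch of that two-step driving code, assuming the usual Triton analysis header locations and a `mlir::ModuleOp` that has already been through the TritonGPU pipeline; the function name `insertBarriers` is illustrative only.

```cpp
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"

// Sketch only: drive the refactored membar analysis on a module.
// `mod` is assumed to carry TritonGPU layouts so that the Allocation
// analysis can assign shared-memory offsets and alias information.
void insertBarriers(mlir::ModuleOp mod) {
  mlir::Allocation allocation(mod);         // shared-memory allocation/alias analysis
  mlir::MembarAnalysis membar(&allocation); // construction only records the allocation
  membar.run();                             // barriers are inserted here, explicitly
}
```

Splitting construction from `run()` lets a pass hold the analysis object, decide when mutation of the IR is safe, and keeps the constructor free of rewriting side effects.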