[Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)

1. Add the missing barriers and revert the previous temporary workaround.
2. Extract the `run` method from the membar analysis, because the analysis
should have two phases: construction, which doesn't modify any IR, and
modification, which inserts barrier ops. Hopefully this makes the use of
membar clearer.
Author: Keren Zhou
Date: 2022-12-01 11:54:18 -08:00 (committed by GitHub)
Parent: 9def1bcebf
Commit: c280ebda1b
8 changed files with 170 additions and 102 deletions
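For reference, the intended two-phase usage described above looks roughly like this (a sketch based on the driver change in the diff below; `mod` is the module being lowered):

```cpp
// Sketch of the new two-phase membar API (names taken from the diff below).
Allocation allocation(mod);              // shared-memory allocation analysis
MembarAnalysis membarPass(&allocation);  // phase 1: construction, no IR is modified
membarPass.run();                        // phase 2: inserts gpu.barrier ops where needed
```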

View File

@@ -29,7 +29,11 @@ public:
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis on the given operation and inserts barriers
/// where necessary.
void run();
private:
struct RegionInfo {
@@ -82,10 +86,6 @@ private:
}
};
/// Runs the membar analysis on the given operation and inserts barriers
/// where necessary.
void run();
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:

View File

@@ -24,6 +24,7 @@ void MembarAnalysis::dfsOperation(Operation *operation,
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
@@ -39,6 +40,27 @@ void MembarAnalysis::dfsOperation(Operation *operation,
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may share the same
// memory as %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
}
}
}
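In effect, the fix traverses the body of an scf.for twice, so that shared-memory writes from the end of one iteration are still treated as unsynced when the reads at the top of the next iteration are analyzed. A minimal sketch of that control flow, using only names that appear in this hunk:

```cpp
// Sketch, not the full implementation: traverseRegions() propagates the
// pending shared-memory accesses (RegionInfo) through each region of `operation`.
traverseRegions();                    // single pass, as for any region-holding op
if (isa<scf::ForOp>(operation)) {
  // Second pass: the loop body is re-analyzed against the state produced by
  // the first pass, modeling values carried from iteration i into iteration i+1.
  traverseRegions();
}
```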
@@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// Do not insert barriers before control flow operations and
// alloc/extract/insert
// alloc is an allocation op without memory write.
// In contrast, arith.constant is an allocation op with memory write.
// FIXME(Keren): extract is always alias for now
// FIXME(Keren): extract_slice is always alias for now
return;
}

View File

@@ -644,7 +644,6 @@ public:
return multiDimIdx;
}
struct SmallVectorKeyInfo {
static unsigned getHashValue(const SmallVector<unsigned> &key) {
return llvm::hash_combine_range(key.begin(), key.end());
@@ -1940,8 +1939,8 @@ struct MakeRangeOpConversion
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
// TODO: slice layout has more elements than expected.
// Unexpected behavior for make range, but generally ok when followed by expand dims + broadcast.
// very weird behavior otherwise potentially.
// Unexpected behavior for make range, but generally ok when followed by
// expand dims + broadcast. very weird behavior otherwise potentially.
for (const auto multiDim : llvm::enumerate(idxs)) {
assert(multiDim.value().size() == 1);
retVals[multiDim.index()] = add(multiDim.value()[0], start);
@@ -2662,7 +2661,8 @@ public:
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize = std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
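As a quick standalone illustration of the packing width computed above (not part of the diff), 32-bit vectors are filled with as many elements as fit, clamped to at least one:

```cpp
#include <algorithm>
#include <cstdio>
#include <initializer_list>

int main() {
  // vecSize = max(32 / bitWidth, 1): i8 -> 4, f16 -> 2, f32 -> 1, f64 -> 1
  for (unsigned bitWidth : {8u, 16u, 32u, 64u}) {
    unsigned vecSize = std::max<unsigned>(32u / bitWidth, 1u);
    std::printf("element width %2u bits -> %u element(s) per packed 32-bit vector\n",
                bitWidth, vecSize);
  }
  return 0;
}
```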
@@ -2688,9 +2688,10 @@ public:
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view =
getStructFromElements(loc, reorderedVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -3138,10 +3139,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
// TODO[Keren]: A temporary workaround for an issue from membar pass.
// https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
barrier();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
@@ -4698,7 +4695,8 @@ public:
decomposeInsertSliceAsyncOp(mod);
Allocation allocation(mod);
MembarAnalysis membar(&allocation);
MembarAnalysis membarPass(&allocation);
membarPass.run();
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);

View File

@@ -50,7 +50,8 @@ public:
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstDotOperand =
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParent = dstDotOperand.getParent();
if (dstDotOperand.getOpIdx() == 1 ||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
@@ -61,11 +62,13 @@ public:
return mlir::failure();
SetVector<Operation *> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
if(llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
if (llvm::find_if(bwdSlices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) == bwdSlices.end())
return mlir::failure();
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -601,7 +604,6 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
@@ -628,7 +630,9 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if(llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};

View File

@@ -1256,8 +1256,7 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
;
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())

View File

@@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
// CHECK-NEXT: Membar 6
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 9
@@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
return
}
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
// So we need a barrier both before and after cst1
// CHECK-LABEL: for_reuse
func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 10
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
// CHECK-LABEL: for_reuse_nested
func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 11
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
}

View File

@@ -26,7 +26,9 @@ struct TestMembarPass
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
Allocation allocation(operation);
MembarAnalysis analysis(&allocation);
MembarAnalysis membarPass(&allocation);
membarPass.run();
size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {