[Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)

1. Add missing barriers and revert the previous temporary solution
2. Extract the `run` method from the membar analysis, because the
analysis should have two distinct phases: construction, which doesn't
modify any IR, and modification, which inserts barrier ops into the IR.
This should make the use of membar clearer.
This commit is contained in:
Keren Zhou
2022-12-01 11:54:18 -08:00
committed by GitHub
parent 9def1bcebf
commit c280ebda1b
8 changed files with 170 additions and 102 deletions

View File

@@ -50,22 +50,25 @@ public:
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstDotOperand =
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParent = dstDotOperand.getParent();
if(dstDotOperand.getOpIdx()==1 ||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
if (dstDotOperand.getOpIdx() == 1 ||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if(dstParentMma.getVersion() == 1 ||
dstParentMma.getWarpsPerCTA()[1] > 1)
if (dstParentMma.getVersion() == 1 ||
dstParentMma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
SetVector<Operation*> bwdSlices;
SetVector<Operation *> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
if(llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
if (llvm::find_if(bwdSlices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) == bwdSlices.end())
return mlir::failure();
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -601,10 +604,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1, shape, numWarps);
@@ -624,35 +626,37 @@ SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
}
// Computes the two-entry warps-per-tile vector for a dot op ("V2" variant).
// If the forward slice of the dot's result contains a triton::DotOp, returns
// {numWarps, 1}. Otherwise starts from {1, 1} and keeps doubling one entry
// until ret[0] * ret[1] >= numWarps; which entry is doubled is decided by
// comparing shape[0] and shape[1] scaled by the fixed shapePerWarp {16, 8}.
// NOTE(review): presumably "V2" refers to MMA version 2 -- confirm with callers.
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation*> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if(llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
// A DotOp found in the forward slice of this dot's result: assign all warps
// to the first entry and skip the balancing loop below.
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
// NOTE(review): `changed` is only ever assigned false in the visible code and
// is never read; the loop exits solely through the `break` below.
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
// Stop once the product of the two entries reaches numWarps.
if (ret[0] * ret[1] >= numWarps)
break;
// Compare the remaining extent along each dimension (integer division);
// the second dimension is measured in units of 2 * shapePerWarp[1].
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
// Grow ret[0] only while it is below shape[0] / shapePerWarp[0];
// otherwise grow ret[1] instead.
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
}
} while (true);
return ret;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
} // namespace

View File

@@ -130,10 +130,10 @@ LogicalResult Prefetcher::initialize() {
if (dotsInFor.empty())
return failure();
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
if(dotsInFor.size() > 1)
if (dotsInFor.size() > 1)
return failure();
// returns source of cvt