[BACKEND] Make flash attention forward pass work (#928)

This also simplifies BroadcastOp codegen
This commit is contained in:
Philippe Tillet
2022-11-30 11:13:24 +01:00
committed by GitHub
parent 4e6a8209ed
commit 6461254fb5
7 changed files with 326 additions and 205 deletions

View File

@@ -50,10 +50,22 @@ public:
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParent = dstDotOperand.getParent();
if(dstDotOperand.getOpIdx()==1 ||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if(dstParentMma.getVersion() == 1 ||
dstParentMma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
SetVector<Operation*> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
if(llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
return mlir::failure();
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op->getContext(), 1, 1, 1, {1, 0}));
RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -81,8 +93,11 @@ public:
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
// return mlir::failure();
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -586,12 +601,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
}
}
template <int version>
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
int numWarps);
template <>
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
@@ -611,33 +623,40 @@ SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
return ret;
}
template <>
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(2, shape, numWarps);
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
SetVector<Operation*> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if(llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
} while (true);
return ret;
}
} // namespace
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -646,13 +665,14 @@ public:
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTile<1>(shape, numWarps);
return warpsPerTileV1(dotOp, shape, numWarps);
case 2:
return warpsPerTile<2>(shape, numWarps);
return warpsPerTileV2(dotOp, shape, numWarps);
default:
assert(false && "not supported version");
return {0, 0};
@@ -684,7 +704,7 @@ public:
retShape, oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), version,
getWarpsPerTile(retShape, version, numWarps)));
getWarpsPerTile(dotOp, retShape, version, numWarps)));
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -732,7 +752,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<SimplifyConversion>(context);
// patterns.add<DecomposeDotOperand>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);

View File

@@ -130,6 +130,11 @@ LogicalResult Prefetcher::initialize() {
if (dotsInFor.empty())
return failure();
// TODO: this segfaults (the original `for` op still has uses)
// when used in flash attention, which has 2 dots in the loop
if(dotsInFor.size() > 1)
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {