[BACKEND] Make flash attention forward pass work (#928)
This also simplifies BroadcastOp codegen
@@ -50,10 +50,22 @@ public:
     auto dstType = convert.getType().cast<RankedTensorType>();
     if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
         dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
+      auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
+      auto dstParent = dstDotOperand.getParent();
+      if (dstDotOperand.getOpIdx() == 1 ||
+          !dstParent.isa<triton::gpu::MmaEncodingAttr>())
+        return mlir::failure();
+      auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
+      if (dstParentMma.getVersion() == 1 ||
+          dstParentMma.getWarpsPerCTA()[1] > 1)
+        return mlir::failure();
+      SetVector<Operation *> bwdSlices;
+      mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
+      if (llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
+        return mlir::failure();
+
-      auto tmpType =
-          RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
-                                triton::gpu::SharedEncodingAttr::get(
-                                    op->getContext(), 1, 1, 1, {1, 0}));
+      auto tmpType = RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
       auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
           convert.getLoc(), tmpType, convert.getOperand());
       auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
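
The guard added above narrows DecomposeDotOperand to the one case fused attention needs: a #blocked -> #dot_operand conversion is decomposed only when it targets operand 0 of an MMA-v2 dot whose warps are not split along the N dimension, and only when the value being converted itself descends from an earlier tt.dot. That last test is the backward-slice check; read in isolation it amounts to the predicate below (a sketch only: dependsOnDot is an illustrative name, not a helper introduced by this patch).

    #include "mlir/Analysis/SliceAnalysis.h"

    // Does `v` transitively depend on the result of a tt.dot? In fused
    // attention this is true for the softmax output feeding the second dot.
    static bool dependsOnDot(mlir::Value v) {
      llvm::SetVector<mlir::Operation *> bwdSlices;
      mlir::getBackwardSlice(v, &bwdSlices); // all transitive producers of v
      return llvm::any_of(bwdSlices, [](mlir::Operation *op) {
        return llvm::isa<mlir::triton::DotOp>(op);
      });
    }

Routing the decomposition through dstParentMma instead of a shared encoding keeps that intermediate in the MMA register layout rather than bouncing it through shared memory.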
@@ -81,8 +93,11 @@ public:
     auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
     // we don't handle conversions to DotOperandEncodingAttr
     // this is a heuristic to accommodate fused attention
-    // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
-    //   return mlir::failure();
+    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = convert.getType().cast<RankedTensorType>();
+    if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
+        srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
+      return mlir::failure();
     // convert to the same layout -- we can delete
     if (op->getResultTypes() == op->getOperandTypes()) {
       rewriter.replaceOp(op, op->getOperands());
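
With the commented-out heuristic replaced by a real check, SimplifyConversion now declines to touch conversions that turn an MMA-layout value into a dot operand; per the comment, this accommodates fused attention, where the first dot's output is fed (after elementwise work) into the second dot and the conversion must survive for the dot-operand patterns to handle. The condition reads as the sketch below (the pass inlines it rather than defining a helper).

    // True for the #mma -> #dot_operand conversions this pattern now skips.
    static bool isMmaToDotOperand(mlir::RankedTensorType srcType,
                                  mlir::RankedTensorType dstType) {
      return srcType.getEncoding().isa<mlir::triton::gpu::MmaEncodingAttr>() &&
             dstType.getEncoding().isa<mlir::triton::gpu::DotOperandEncodingAttr>();
    }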
@@ -586,12 +601,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }
 
-template <int version>
-SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
-                                      int numWarps);
-
-template <>
-SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
+SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
+                                        const ArrayRef<int64_t> shape,
                                         int numWarps) {
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp =
@@ -611,33 +623,40 @@ SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
   return ret;
 }
 
-template <>
-SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
+SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
+                                        const ArrayRef<int64_t> shape,
                                         int numWarps) {
-  SmallVector<unsigned, 2> ret = {1, 1};
-  SmallVector<int64_t, 2> shapePerWarp =
-      mmaVersionToShapePerWarp(2, shape, numWarps);
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
-    } else {
-      ret[1] *= 2;
-    }
-  } while (true);
-  return ret;
-}
+  SetVector<Operation *> slices;
+  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  if (llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
+    return {(unsigned)numWarps, 1};
+
+  SmallVector<unsigned, 2> ret = {1, 1};
+  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+  bool changed = false;
+  // TODO (@daadaada): double-check.
+  // original logic in
+  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
+  // seems buggy for shape = [32, 16] ?
+  do {
+    changed = false;
+    if (ret[0] * ret[1] >= numWarps)
+      break;
+    if (shape[0] / shapePerWarp[0] / ret[0] >=
+        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
+      if (ret[0] < shape[0] / shapePerWarp[0]) {
+        ret[0] *= 2;
+      } else
+        ret[1] *= 2;
+    } else {
+      ret[1] *= 2;
+    }
+  } while (true);
+  return ret;
+}
 
 } // namespace
 
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;
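
warpsPerTileV2 now receives the dot itself so it can look downstream: if the dot's result transitively feeds another tt.dot (getForwardSlice), as in the flash-attention loop, the warp tile is pinned to {numWarps, 1} and the greedy split is skipped. The greedy split also now assumes the fixed MMA-v2 warp shape {16, 8} instead of calling mmaVersionToShapePerWarp. A worked trace of the loop, assuming shape = {128, 64} and numWarps = 4:

    ret = {1, 1}: 128/16/1 = 8 >= 64/(8*2)/1 = 4, and 1 < 8 -> ret = {2, 1}
    ret = {2, 1}: 128/16/2 = 4 >= 64/(8*2)/1 = 4, and 2 < 8 -> ret = {4, 1}
    ret = {4, 1}: ret[0] * ret[1] = 4 >= numWarps -> break

So a standalone 128x64 dot is tiled {4, 1}, while a dot that feeds another dot short-circuits to {numWarps, 1}. (Note that the `changed` flag introduced here is assigned but never read in the lines shown.)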
@@ -646,13 +665,14 @@ public:
       : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
         computeCapability(computeCapability) {}
 
-  static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
+  static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
+                                                  const ArrayRef<int64_t> shape,
                                                   int version, int numWarps) {
     switch (version) {
     case 1:
-      return warpsPerTile<1>(shape, numWarps);
+      return warpsPerTileV1(dotOp, shape, numWarps);
     case 2:
-      return warpsPerTile<2>(shape, numWarps);
+      return warpsPerTileV2(dotOp, shape, numWarps);
     default:
       assert(false && "not supported version");
       return {0, 0};
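
The version templates are dropped because the heuristic is no longer a pure function of the shape: warpsPerTileV2 must walk the forward slice of the op, so the op handle is threaded through getWarpsPerTile. Call sites update mechanically, as in the next hunk (variable name illustrative):

    auto warpsPerCTA = getWarpsPerTile(dotOp, retShape, version, numWarps);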
@@ -684,7 +704,7 @@ public:
         retShape, oldRetType.getElementType(),
         triton::gpu::MmaEncodingAttr::get(
             oldRetType.getContext(), version,
-            getWarpsPerTile(retShape, version, numWarps)));
+            getWarpsPerTile(dotOp, retShape, version, numWarps)));
     // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
     auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -732,7 +752,7 @@ public:
     mlir::RewritePatternSet patterns(context);
 
     patterns.add<SimplifyConversion>(context);
-    // patterns.add<DecomposeDotOperand>(context);
+    patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);
     patterns.add<RematerializeForward>(context);
     patterns.add<MoveConvertOutOfLoop>(context);
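
DecomposeDotOperand returns to the pattern set now that its new guard makes it fire only on the fused-attention shape of IR. For context, this set is consumed by MLIR's greedy rewrite driver; the surrounding pass typically ends with the standard idiom below (not part of this hunk; shown as a sketch, with `m` the module being rewritten):

    if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();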
@@ -130,6 +130,11 @@ LogicalResult Prefetcher::initialize() {
 
   if (dotsInFor.empty())
     return failure();
 
+  // TODO: segfault (original for still has uses)
+  // when used in flash attention that has 2 dots in the loop
+  if (dotsInFor.size() > 1)
+    return failure();
+
   // returns source of cvt
   auto getPrefetchSrc = [](Value v) -> Value {
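
The early bail-out above is the one functional change to the prefetcher: a loop with two dots (the flash-attention forward loop) currently leaves uses of the original scf.for behind and segfaults, per the TODO, so such loops are skipped. dotsInFor is populated earlier in initialize(); the scan has roughly this shape (illustrative sketch, assuming the loop body is walked for tt.dot ops):

    llvm::SmallVector<mlir::triton::DotOp> dotsInFor;
    forOp.getBody()->walk(
        [&](mlir::triton::DotOp dot) { dotsInFor.push_back(dot); });
    if (dotsInFor.size() > 1)
      return mlir::failure(); // two-dot loops are not prefetched yet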