[WIP][Triton-MLIR] Prefetch pass fixup (#873)

This PR addresses a (potential) problem caused by directly adopting `tensor.extract_slice`.

Long story short, `tensor.extract_slice` is not aware of swizzling.
Consider the following shared memory tensor and its first three slices,
where each slice includes two tiles (the loading unit of LDGSTS) of
elements. At this point the tiles haven't been swizzled yet, so slicing
appears to work.

<img width="1219" alt="image"
src="https://user-images.githubusercontent.com/2306281/201833023-a7950705-2d50-4c0a-8527-7505261c3a3c.png">

However, now consider the following figure, which shows the same layout
after swizzling has been applied.

<img width="1244" alt="image"
src="https://user-images.githubusercontent.com/2306281/201834824-7daae360-f5bc-4e6b-a921-20be3f294b78.png">

Note that in phase 2, all tiles have been swizzled out of their
original slices. This implies that if we use the tile index computed
after slicing, we can no longer locate the correct tiles. For example,
T3 was in slice 1 but got swapped to slice 0 after swizzling.

Here's a more detailed explanation. In the current `triton-mlir` branch,
we only compute the relative offset of each tile. So T3's index in Slice
1 is *1*, and it will be swizzled using *1* and its *phase id*. The
correct index of T3, however, is *3*: its offset relative to the
beginning of the shared memory tensor being swizzled. T3 should
therefore be swizzled using *3* and its *phase id*.
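
To make the failure concrete, here is a deliberately simplified sketch.
The names are illustrative and this is not Triton's actual swizzle
(which also involves `vec`, `perPhase`, and `maxPhase`), but the
aliasing problem is the same:

```cpp
// Illustrative only: a bare XOR swizzle standing in for the real one.
unsigned swizzle(unsigned tileIdx, unsigned phaseId) {
  return tileIdx ^ phaseId;
}

// With two tiles per slice and phase id 2:
//   T3's absolute index is 3:  swizzle(3, 2) == 1 -> slice 0, as in the figure
//   T3's relative index is 1:  swizzle(1, 2) == 3 -> the wrong tile
```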

This PR proposes a hacky solution to this problem. We restore the
"correct" offset of each tile by **assuming that slicing on a specific
dim happens at most once on the output of `insert_slice_async`**. I
admit it's risky and fragile.
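
As a minimal sketch of what the workaround computes, under that
single-slice assumption (the helper name is hypothetical, not from the
patch):

```cpp
// Hypothetical helper: recover a tile's offset relative to the start of
// the whole shared memory tensor, so swizzling uses the right index.
// Only sound if extract_slice touches the dim at most once; a second
// slice would make sliceOffset ambiguous.
unsigned restoreAbsoluteIndex(unsigned relativeIdx, unsigned sliceOffset) {
  return sliceOffset + relativeIdx; // e.g. Slice 1 starts at tile 2: 2 + 1 == 3
}
```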

The other possible solution is to adopt cutlass' swizzling logic, which
limits the indices being swizzled to a "bounding box" matching the tile
the mma instruction executes on. For example, in the following tensor
layout, each 4x4 submatrix is a minimal swizzling unit, and the entire
tensor represents the layout of operand A in `mma.16816`.

<img width="565" alt="image"
src="https://user-images.githubusercontent.com/2306281/201836879-4ca7824b-530c-4a06-a3d5-1e74a2de1b42.png">

Co-authored-by: Phil Tillet <phil@openai.com>

Author: Keren Zhou
Date: 2022-11-19 19:57:16 -08:00 (committed by GitHub)
Parent: e8994209f4 · Commit: 6c5f646f4e
7 changed files with 146 additions and 70 deletions


```diff
@@ -33,9 +33,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     SmallVector<unsigned, 4> sizePerThread(rank, 1);
     PointerType ptrType = origType.getElementType().cast<PointerType>();
     auto pointeeType = ptrType.getPointeeType();
-    unsigned numBits =
-        pointeeType.isa<triton::Float8Type>() ?
-        8 : pointeeType.getIntOrFloatBitWidth();
+    unsigned numBits = pointeeType.isa<triton::Float8Type>()
+                           ? 8
+                           : pointeeType.getIntOrFloatBitWidth();
     unsigned maxMultiple = info.getDivisibility(order[0]);
     unsigned maxContig = info.getContiguity(order[0]);
     unsigned alignment = std::min(maxMultiple, maxContig);
```


```diff
@@ -78,8 +78,6 @@ public:
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
       return mlir::failure();
     auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
-    auto dstType = convert.getType().cast<RankedTensorType>();
     // we don't handle conversions to DotOperandEncodingAttr
     // this is a heuristics to accommodate fused attention
     // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
@@ -96,6 +94,9 @@ public:
     // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
     auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
     if (alloc_tensor) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
           op, op->getResult(0).getType());
       return mlir::success();
@@ -103,6 +104,9 @@ public:
     // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
     auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
     if (insert_slice) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       auto newType = op->getResult(0).getType().cast<RankedTensorType>();
       // Ensure that the new insert_slice op is placed in the same place as the
       // old insert_slice op. Otherwise, the new insert_slice op may be placed
@@ -121,6 +125,9 @@ public:
     // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
     auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
     if (extract_slice) {
+      if (!isSharedEncoding(op->getResult(0))) {
+        return mlir::failure();
+      }
       auto origType = extract_slice.source().getType().cast<RankedTensorType>();
       auto newType = RankedTensorType::get(
           origType.getShape(), origType.getElementType(),
@@ -144,16 +151,15 @@ public:
       return mlir::success();
     }
-    // cvt(type2, x)
+    // cvt(cvt(x, type1), type2) -> cvt(x, type2)
     if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
-      auto argType = arg->getOperand(0).getType().cast<RankedTensorType>();
       if (arg->getOperand(0).getDefiningOp() &&
-          !argType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
-          srcType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
-          !dstType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
+          !isSharedEncoding(arg->getOperand(0)) &&
+          isSharedEncoding(convert.getOperand()) &&
+          !isSharedEncoding(convert.getResult())) {
         return mlir::failure();
       }
+      auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
       auto srcShared =
           srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
       if (srcShared && srcShared.getVec() > 1)
```
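
The new early-outs rely on the `isSharedEncoding` helper declared in
`triton/Analysis/Utility.h` (the header added by the include hunk
below). A minimal sketch of its behavior, assuming the MLIR casting
idioms of this era:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// True iff the value is a ranked tensor whose layout encoding is
// triton::gpu::SharedEncodingAttr.
bool isSharedEncoding(Value value) {
  if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
    Attribute encoding = tensorType.getEncoding();
    return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
  }
  return false;
}
```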


```diff
@@ -27,6 +27,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -60,7 +61,7 @@ class Prefetcher {
   LogicalResult isForOpOperand(Value v);
-  Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
+  Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                          Attribute dotEncoding, OpBuilder &builder,
                          llvm::Optional<int64_t> offsetK = llvm::None,
                          llvm::Optional<int64_t> shapeK = llvm::None);
@@ -79,7 +80,7 @@ public:
   scf::ForOp createNewForOp();
 };
-Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
+Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                    Attribute dotEncoding, OpBuilder &builder,
                                    llvm::Optional<int64_t> offsetK,
                                    llvm::Optional<int64_t> shapeK) {
@@ -94,8 +95,8 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch,
   // k => (prefetchWidth, k - prefetchWidth)
   int64_t kIdx = opIdx == 0 ? 1 : 0;
-  offset[kIdx] = isPrefetch ? 0 : prefetchWidth;
-  shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth);
+  offset[kIdx] = isPrologue ? 0 : prefetchWidth;
+  shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
   if (shapeK)
     shape[kIdx] = *shapeK;
```
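
For reference, the renamed `isPrologue` flag selects which part of the
k dimension a prefetch covers. A standalone sketch of the split,
assuming `prefetchWidth` is 16 and `opIdx == 0` (so k is dim 1 of an
[M, K] operand):

```cpp
#include <cstdint>
#include <utility>

// Returns {offsetK, shapeK}: the prologue prefetches k in
// [0, prefetchWidth); the loop body consumes k in [prefetchWidth, K).
std::pair<int64_t, int64_t> kSlice(int64_t K, bool isPrologue,
                                   int64_t prefetchWidth = 16) {
  int64_t offsetK = isPrologue ? 0 : prefetchWidth;
  int64_t shapeK = isPrologue ? prefetchWidth : (K - prefetchWidth);
  return {offsetK, shapeK}; // K = 64: prologue {0, 16}, loop body {16, 48}
}
```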
```diff
@@ -132,9 +133,9 @@ LogicalResult Prefetcher::initialize() {
   // returns source of cvt
   auto getPrefetchSrc = [](Value v) -> Value {
-    // TODO: Check if the layout of src is SharedEncodingAttr
     if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
-      return cvt.src();
+      if (isSharedEncoding(cvt.getOperand()))
+        return cvt.src();
     return Value();
   };
@@ -152,6 +153,10 @@
   };
   for (triton::DotOp dot : dotsInFor) {
     auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
+    // Skip prefetching if kSize is less than prefetchWidth
+    if (kSize < prefetchWidth)
+      continue;
     Value aSmem = getPrefetchSrc(dot.a());
     Value bSmem = getPrefetchSrc(dot.b());
     if (aSmem && bSmem) {
```
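
The new `kSize < prefetchWidth` guard follows from the split above:
with a `prefetchWidth` of 16 (as assumed in the earlier sketch), a dot
with kSize = 8 would otherwise get a loop-body slice of extent
8 - 16 = -8, so such dots are simply left unprefetched.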
```diff
@@ -217,7 +222,7 @@ scf::ForOp Prefetcher::createNewForOp() {
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
   for (Operation &op : forOp.getBody()->without_terminator()) {
-    Operation *newOp = nullptr;
+    Operation *newOp = builder.clone(op, mapping);
     auto dot = dyn_cast<triton::DotOp>(&op);
     if (dots.contains(dot)) {
       Attribute dotEncoding =
@@ -252,8 +257,6 @@
         kOff += kShape;
         kRem -= kShape;
       }
-    } else {
-      newOp = builder.clone(op, mapping);
     }
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
```