A (potential) problem with directly adopting `tensor.extract_slice`. Long story short, `tensor.extract_slice` is not aware of swizzling. Consider the following shared memory tensor and its first three slices, where each slice includes two tiles (the loading unit of LDGSTS) of elements. Currently, the tiles haven't been swizzled yet, so slicing seems to work. <img width="1219" alt="image" src="https://user-images.githubusercontent.com/2306281/201833023-a7950705-2d50-4c0a-8527-7505261c3a3c.png"> However, now consider the following figure, which shows the layout after applying swizzling to the first figure. <img width="1244" alt="image" src="https://user-images.githubusercontent.com/2306281/201834824-7daae360-f5bc-4e6b-a921-20be3f294b78.png"> Note that on phase 2, all tiles have been swizzled out of their original slices. This implies that if we use the tile index after slicing, we can no longer locate the correct tiles. For example, T3 was in slice 1 but got swapped to slice 0 after swizzling. Here's a more detailed explanation. In the current `triton-mlir` branch, we only compute the relative offset of each tile. So T3's index in Slice 1 is *1*, and it will be swizzled using *1* and *phase id*. Whereas the correct index of T3 should be *3*, which is the relative offset to the beginning of the shared memory tensor being swizzled, and T3 should be swizzled using *3* and *phase id*. This PR proposes a hacky solution for this problem. We restore the "correct" offset of each tile by **assuming that slicing on a specific dim only happens at most once on the output of insert_slice_async**. I admit it's risky and fragile. The other possible solution is adopting cutlass' swizzling logic, which limits the indices being swizzled to a "bounding box" that matches what the mma instruction executes. For example, in the following tensor layout, each 4x4 submatrix is a minimum swizzling unit, and the entire tensor represents the tensor layout of operand A in `mma.16816`. 
<img width="565" alt="image" src="https://user-images.githubusercontent.com/2306281/201836879-4ca7824b-530c-4a06-a3d5-1e74a2de1b42.png"> Co-authored-by: Phil Tillet <phil@openai.com>
140 lines
5.8 KiB
C++
140 lines
5.8 KiB
C++
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include <algorithm>
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
|
|
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
  /// Compute a "coalesced" blocked encoding for the tensor of pointers `ptr`:
  /// dimensions are ordered by decreasing contiguity (as reported by axis-info
  /// analysis), and the per-thread tile size along the most-contiguous
  /// dimension is sized so each thread covers up to 128 contiguous bits,
  /// bounded by the pointer's alignment/contiguity and by the number of
  /// elements actually available per thread.
  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
                                 int numWarps) {
    auto origType = ptr.getType().cast<RankedTensorType>();
    // Get the shape of the tensor.
    size_t rank = origType.getRank();
    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    // Layout order in decreasing order of contiguity.
    SmallVector<unsigned, 4> order(rank);
    std::iota(order.begin(), order.end(), 0);
    auto contiguity = info.getContiguity();
    std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
      return contiguity[x] > contiguity[y];
    });

    int numElems = product(origType.getShape());
    int numThreads = numWarps * 32; // NOTE(review): assumes 32-thread warps
    int numElemsPerThread = std::max(numElems / numThreads, 1);

    // Thread tile size depends on memory alignment: a thread may own at most
    // 128 bits (the widest vectorized access) of contiguous elements.
    SmallVector<unsigned, 4> sizePerThread(rank, 1);
    PointerType ptrType = origType.getElementType().cast<PointerType>();
    auto pointeeType = ptrType.getPointeeType();
    // Float8 is special-cased because it has no getIntOrFloatBitWidth().
    unsigned numBits = pointeeType.isa<triton::Float8Type>()
                           ? 8
                           : pointeeType.getIntOrFloatBitWidth();
    unsigned maxMultiple = info.getDivisibility(order[0]);
    unsigned maxContig = info.getContiguity(order[0]);
    unsigned alignment = std::min(maxMultiple, maxContig);
    unsigned perThread = std::min(alignment, 128 / numBits);
    sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);

    // Create the blocked encoding. (A previously-present unused `dims`
    // vector has been removed.)
    return triton::gpu::BlockedEncodingAttr::get(
        &getContext(), origType.getShape(), sizePerThread, order, numWarps);
  }

  /// Return a converter that maps any ranked tensor type to one with the same
  /// shape and element type but with the coalesced encoding derived from
  /// `ptr`.
  std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
                                             Value ptr, int numWarps) {
    Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
    return [encoding](Type _type) {
      RankedTensorType type = _type.cast<RankedTensorType>();
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   encoding);
    };
  }

  /// Rewrite memory op `op` so that its tensor operands use the coalesced
  /// layout derived from `ptr`, then convert the results back to their
  /// original layouts so existing users remain valid, and erase the old op.
  template <class T>
  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
                  OpBuilder builder) {
    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
    if (!ty)
      return;
    auto mod = op->getParentOfType<ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    // (Removed a dead `AxisInfo info` lookup whose result was never used.)
    auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
    // Convert operands: insert a layout conversion in front of every tensor
    // operand that is not already in a shared-memory encoding.
    SmallVector<Value, 4> newArgs;
    for (auto v : op->getOperands()) {
      auto vTy = v.getType().dyn_cast<RankedTensorType>();
      if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), convertType(v.getType()), v));
      else
        newArgs.push_back(v);
    }
    // Convert output types. InsertSliceAsyncOp's result stays in shared
    // memory, so its type must be left untouched.
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes()) {
      bool isAsync = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
      newTypes.push_back(isAsync ? t : convertType(t));
    }
    // Construct a new op with the new encoding.
    Operation *newOp =
        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
    // Cast the results back to the original layout so users are unaffected.
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    // Run axis info analysis.
    AxisInfoAnalysis axisInfo(&getContext());
    axisInfo.run(op);
    OpBuilder builder(op);

    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
    op->walk([&](Operation *curr) {
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPoint(curr);
      if (auto load = dyn_cast<triton::LoadOp>(curr))
        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
        coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
        coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
      if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
        coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
                                                    builder);
      if (auto store = dyn_cast<triton::StoreOp>(curr))
        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
    });
  }
};
|
|
|
|
std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
|
|
return std::make_unique<CoalescePass>();
|
|
}
|