[OPTIMIZER] Added memory coalescing pass (#31)

Philippe Tillet
2022-07-31 20:59:31 -07:00
committed by GitHub
parent d1593e6ca8
commit 3236642e8f
7 changed files with 180 additions and 4 deletions

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

@@ -3,6 +3,7 @@ mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
  Coalesce.cpp
  Combine.cpp
  Pipeline.cpp
  Verifier.cpp

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

@@ -0,0 +1,108 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
    auto origType = ptr.getType().cast<RankedTensorType>();
    // Get the rank of the tensor.
    size_t rank = origType.getRank();
    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    // Order the dimensions by decreasing contiguity.
    SmallVector<unsigned, 4> order(rank);
    std::iota(order.begin(), order.end(), 0);
    auto contiguity = info.getContiguity();
    std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
      return contiguity[x] > contiguity[y];
    });
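    // e.g., for a row-major 128x64 tensor the contiguity might be [1, 64],
    // giving order = [1, 0] (innermost, most contiguous dimension first).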
    // Thread tile size depends on memory alignment
    SmallVector<unsigned, 4> sizePerThread(rank, 1);
    PointerType ptrType = origType.getElementType().cast<PointerType>();
    unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
    unsigned maxMultiple = info.getDivisibility(order[0]);
    unsigned maxContig = info.getContiguity(order[0]);
    unsigned alignment = std::min(maxMultiple, maxContig);
    unsigned perThread = std::min(alignment, 128 / numBits);
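    // e.g., for f32 pointers (numBits = 32) with contiguity and divisibility
    // of 16 along the fastest dimension: perThread = min(16, 128 / 32) = 4,
    // so each thread covers four consecutive elements, which allows a single
    // 128-bit vectorized memory access.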
    sizePerThread[order[0]] = perThread;
    // create encoding
    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
        &getContext(), origType.getShape(), sizePerThread, order,
        this->numWarps);
    return encoding;
  }
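
  // Returns a type converter that re-tags any ranked tensor type with the
  // coalesced encoding while preserving its shape and element type.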
  std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
                                             Value ptr) {
    Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
    return [encoding](Type _type) {
      RankedTensorType type = _type.cast<RankedTensorType>();
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   encoding);
    };
  }
  template <class T>
  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
                  OpBuilder builder) {
    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
    if (!ty)
      return;
    auto convertType = getTypeConverter(axisInfo, ptr);
    // convert operands
    SmallVector<Value, 4> newArgs;
    for (auto v : op->getOperands())
      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), convertType(v.getType()), v));
    // convert output types
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes())
      newTypes.push_back(convertType(t));
    // construct new op with the new encoding
    Operation *newOp =
        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
    // cast the results back to the original layout
    for (size_t i = 0; i < op->getNumResults(); i++) {
      auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }
  void runOnOperation() override {
    Operation *op = getOperation();
    // Run axis info analysis
    AxisInfoAnalysis axisInfo(&getContext());
    axisInfo.run(op);
    OpBuilder builder(op);
    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
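    //
    // Illustrative sketch (hypothetical IR, encoding details abbreviated):
    //   %v = tt.load %ptr : tensor<128xf32, #L1>
    // becomes
    //   %cp = triton_gpu.convert_layout %ptr : ... #L1 -> ... #L2
    //   %cv = tt.load %cp : tensor<128xf32, #L2>
    //   %v  = triton_gpu.convert_layout %cv : ... #L2 -> ... #L1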
    op->walk([&](Operation *curr) {
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPoint(curr);
      if (auto load = dyn_cast<triton::LoadOp>(curr))
        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
      if (auto store = dyn_cast<triton::StoreOp>(curr))
        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
    });
  }
};
std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
  return std::make_unique<CoalescePass>();
}
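
For orientation, a minimal sketch of how the new pass might be invoked; the PassManager setup, the driver function, and the module are assumptions for illustration, and only createTritonGPUCoalescePass comes from this commit:

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/ErrorHandling.h"

// Hypothetical driver: 'module' is an mlir::ModuleOp built by earlier
// lowering stages.
void runCoalescing(mlir::MLIRContext &context, mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  pm.addPass(mlir::createTritonGPUCoalescePass());
  // Other TritonGPU passes (e.g., pipelining) would be scheduled here too.
  if (mlir::failed(pm.run(module)))
    llvm::report_fatal_error("coalescing pass failed");
}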