This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to improve its maintainability, stability, and ultimately its performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
#include "triton/Analysis/AxisInfo.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
|
|
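// The coalescing pass rewrites memory operations so that consecutive threads
// access consecutive, well-aligned addresses: for each memory op it derives a
// blocked encoding from AxisInfoAnalysis (dimensions ordered by decreasing
// contiguity, per-thread tile sized from the pointer's alignment), converts
// the operands into that layout, and converts the results back afterwards.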
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
                                 int numWarps) {
    auto origType = ptr.getType().cast<RankedTensorType>();
    // Get the shape of the tensor.
    size_t rank = origType.getRank();
    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    // Layout order in decreasing order of contiguity
    SmallVector<unsigned, 4> order(rank);
    std::iota(order.begin(), order.end(), 0);
    auto contiguity = info.getContiguity();
    std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
      return contiguity[x] > contiguity[y];
    });

    int numElems = product(origType.getShape());
    int numThreads = numWarps * 32;
    int numElemsPerThread = std::max(numElems / numThreads, 1);

    // Thread tile size depends on memory alignment
    SmallVector<unsigned, 4> sizePerThread(rank, 1);
    PointerType ptrType = origType.getElementType().cast<PointerType>();
    auto pointeeType = ptrType.getPointeeType();
    unsigned numBits = pointeeType.isa<triton::Float8Type>()
                           ? 8
                           : pointeeType.getIntOrFloatBitWidth();
    unsigned maxMultiple = info.getDivisibility(order[0]);
    unsigned maxContig = info.getContiguity(order[0]);
    unsigned alignment = std::min(maxMultiple, maxContig);
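    // Each thread handles at most 128 bits (one fully vectorized access) along
    // the most contiguous dimension, and never more elements than the
    // pointer's proven alignment allows.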
    unsigned perThread = std::min(alignment, 128 / numBits);
    sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);

    SmallVector<unsigned> dims(rank);
    std::iota(dims.begin(), dims.end(), 0);
    // create encoding
    Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
        &getContext(), origType.getShape(), sizePerThread, order, numWarps);
    return encoding;
  }

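  // Build a type converter that re-tags ranked tensor types with the coalesced
  // encoding computed for `ptr`; shape and element type are left unchanged.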
  std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
                                             Value ptr, int numWarps) {
    Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
    return [encoding](Type _type) {
      RankedTensorType type = _type.cast<RankedTensorType>();
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   encoding);
    };
  }

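  // Rewrite a memory operation of type T so that its tensor operands and
  // results use the coalesced layout: operands are converted into the new
  // layout, a new op is built, and its results are converted back so existing
  // users keep their original types.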
  template <class T>
  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
                  OpBuilder builder) {
    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
    if (!ty)
      return;
    auto mod = op->getParentOfType<ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
    // convert operands
    SmallVector<Value, 4> newArgs;
    for (auto v : op->getOperands()) {
      auto vTy = v.getType().dyn_cast<RankedTensorType>();
      if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), convertType(v.getType()), v));
      else
        newArgs.push_back(v);
    }
    // convert output types
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes()) {
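      // InsertSliceAsyncOp writes into a shared-memory destination tensor, so
      // its result keeps the destination's encoding instead of the coalesced
      // layout.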
      bool is_async = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
      newTypes.push_back(is_async ? t : convertType(t));
    }
    // construct new op with the new encoding
    Operation *newOp =
        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
    // cast the results back to the original layout
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    // Run axis info analysis
    AxisInfoAnalysis axisInfo(&getContext());
    axisInfo.run(op);
    OpBuilder builder(op);

    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
    op->walk([&](Operation *curr) {
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPoint(curr);
      if (auto load = dyn_cast<triton::LoadOp>(curr))
        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
        coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
        coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
      if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
        coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
                                                    builder);
      if (auto store = dyn_cast<triton::StoreOp>(curr))
        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
    });
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
  return std::make_unique<CoalescePass>();
}
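For reference, a pass constructed through `createTritonGPUCoalescePass()` is scheduled like any other MLIR pass. Below is a minimal sketch assuming only the standard `mlir::PassManager` API; the helper name `addCoalescePass` and the surrounding pipeline are illustrative, not the actual Triton compiler pipeline.

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Illustrative only: schedule the coalescing pass on a module. The real
// Triton pipeline interleaves this with other TritonGPU transforms.
void addCoalescePass(mlir::PassManager &pm) {
  pm.addPass(mlir::createTritonGPUCoalescePass());
}

Since the pass operates on the TritonGPU dialect, it is meant to run after the conversion from the Triton dialect has produced TritonGPU IR.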