#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include using namespace mlir; using namespace mlir::triton; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" struct CoalescePass : public TritonGPUCoalesceBase { Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr, int numWarps) { auto origType = ptr.getType().cast(); // Get the shape of the tensor. size_t rank = origType.getRank(); AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); // Layout order in decreasing order of contiguity SmallVector order(rank); std::iota(order.begin(), order.end(), 0); auto contiguity = info.getContiguity(); std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) { return contiguity[x] > contiguity[y]; }); int numElems = product(origType.getShape()); int numThreads = numWarps * 32; int numElemsPerThread = std::max(numElems / numThreads, 1); // Thread tile size depends on memory alignment SmallVector sizePerThread(rank, 1); PointerType ptrType = origType.getElementType().cast(); auto pointeeType = ptrType.getPointeeType(); unsigned numBits = pointeeType.isa() ? 8 : pointeeType.getIntOrFloatBitWidth(); unsigned maxMultiple = info.getDivisibility(order[0]); unsigned maxContig = info.getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned perThread = std::min(alignment, 128 / numBits); sizePerThread[order[0]] = std::min(perThread, numElemsPerThread); SmallVector dims(rank); std::iota(dims.begin(), dims.end(), 0); // create encoding Attribute encoding = triton::gpu::BlockedEncodingAttr::get( &getContext(), origType.getShape(), sizePerThread, order, numWarps); return encoding; } std::function getTypeConverter(AxisInfoAnalysis &axisInfo, Value ptr, int numWarps) { Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps); return [encoding](Type _type) { RankedTensorType type = _type.cast(); return RankedTensorType::get(type.getShape(), type.getElementType(), encoding); }; } template void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr, OpBuilder builder) { RankedTensorType ty = ptr.getType().template dyn_cast(); if (!ty) return; auto mod = op->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); auto convertType = getTypeConverter(axisInfo, ptr, numWarps); // convert operands SmallVector newArgs; for (auto v : op->getOperands()) { auto vTy = v.getType().dyn_cast(); if (vTy && !vTy.getEncoding().isa()) newArgs.push_back(builder.create( op->getLoc(), convertType(v.getType()), v)); else newArgs.push_back(v); } // convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { bool is_async = std::is_same::value; newTypes.push_back(is_async ? 
  template <class T>
  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
                  OpBuilder builder) {
    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
    if (!ty)
      return;
    auto mod = op->getParentOfType<ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
    // Convert operands to the coalesced layout; operands already in shared
    // memory are left unchanged.
    SmallVector<Value, 4> newArgs;
    for (auto v : op->getOperands()) {
      auto vTy = v.getType().dyn_cast<RankedTensorType>();
      if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), convertType(v.getType()), v));
      else
        newArgs.push_back(v);
    }
    // Convert output types. Async copies write directly to shared memory, so
    // their result type keeps its original encoding.
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes()) {
      bool is_async = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
      newTypes.push_back(is_async ? t : convertType(t));
    }
    // Construct the new op with the coalesced encoding.
    Operation *newOp =
        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
    // Cast the results back to the original layout.
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    // Run axis info analysis.
    AxisInfoAnalysis axisInfo(&getContext());
    axisInfo.run(op);
    OpBuilder builder(op);

    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
    op->walk([&](Operation *curr) {
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPoint(curr);
      if (auto load = dyn_cast<triton::LoadOp>(curr))
        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
        coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
      if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
        coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
      if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
        coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
                                                    builder);
      if (auto store = dyn_cast<triton::StoreOp>(curr))
        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
    });
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
  return std::make_unique<CoalescePass>();
}
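
// Usage sketch (assumed wiring, not part of this file): the pass is intended
// to run inside a TritonGPU pass pipeline, after the conversion to the
// TritonGPU dialect has assigned initial layouts, e.g.:
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::createTritonGPUCoalescePass());
//   if (failed(pm.run(module)))
//     return failure();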