[Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)

What is done in this PR:
- [x] Add `ConvertLayout`, `getSizePerThread`, and `getShapePerCTA`
implementations for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases:
`emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout` (see the
base-plus-offset sketch after this list)
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation using the
PTX instruction `shfl.sync` (see the warp-shuffle sketch after this list)
- [x] Add support for scalar values in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests; all unit tests pass
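
As a rough intuition for the `emitIndices` split, the sketch below models the
two phases on a simplified 1-D blocked layout: the base index depends on the
thread id, while the per-element offsets are thread-invariant and can be
computed once per layout. The helper names echo the PR's functions, but the
bodies and parameters are illustrative, not the actual MLIR lowering.

```cpp
#include <cstdio>
#include <vector>

// Phase 1 analogue of emitBaseIndexForBlockedLayout: the first element
// owned by a given thread (depends on the thread id).
unsigned baseIndex(unsigned tid, unsigned sizePerThread) {
  return tid * sizePerThread;
}

// Phase 2 analogue of emitOffsetForBlockedLayout: offsets of each owned
// element relative to the base (thread-invariant, computed once).
std::vector<unsigned> elementOffsets(unsigned sizePerThread) {
  std::vector<unsigned> offs(sizePerThread);
  for (unsigned i = 0; i < sizePerThread; ++i)
    offs[i] = i;
  return offs;
}

int main() {
  const unsigned sizePerThread = 4;
  auto offs = elementOffsets(sizePerThread); // shared by all threads
  for (unsigned tid = 0; tid < 2; ++tid)
    for (unsigned off : offs)
      std::printf("tid %u -> index %u\n", tid,
                  baseIndex(tid, sizePerThread) + off);
  return 0;
}
```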
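
For the fast reduction path, `shfl.sync` lets the lanes of a warp exchange
register values without a round trip through shared memory. Below is a
minimal standalone CUDA kernel showing the idiom: `__shfl_down_sync` compiles
to the `shfl.sync` PTX instruction. This illustrates the pattern only; it is
not the PR's actual MLIR-level lowering.

```cuda
// Illustrative warp-level sum reduction. Assumes a launch with a single
// 32-thread warp, e.g. warpReduceSum<<<1, 32>>>(in, out);
__global__ void warpReduceSum(const float *in, float *out) {
  float v = in[threadIdx.x];
  // Tree reduction across the 32 lanes: each step halves the number of
  // live partial sums; __shfl_down_sync lowers to PTX shfl.sync.
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffffu, v, offset);
  if (threadIdx.x == 0)
    *out = v; // lane 0 now holds the warp-wide sum
}
```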

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
Author: Qingyi Liu, 2022-10-28 11:07:45 +08:00 (committed via GitHub)
Commit: 42db3538e4 (parent: 3e6cc6d66c)
7 changed files with 680 additions and 57 deletions

@@ -14,6 +14,7 @@ using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
+using ::mlir::triton::gpu::SliceEncodingAttr;
 namespace mlir {
@@ -33,6 +34,10 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
srcLayout = srcSliceLayout.getParent();
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
dstLayout = dstSliceLayout.getParent();
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
@@ -73,6 +78,31 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   return paddedRepShape;
 }
+SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
+  auto srcTy = op.operand().getType().cast<RankedTensorType>();
+  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcShape = srcTy.getShape();
+  auto rank = srcShape.size();
+  auto axis = op.axis();
+  bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
+  SmallVector<unsigned> smemShape;
+  for (auto d : srcShape)
+    smemShape.push_back(d);
+  if (fast_reduce) {
+    unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
+    smemShape[axis] = sizeInterWarps;
+  } else {
+    unsigned threadsPerCTAAxis =
+        srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
+    smemShape[axis] = threadsPerCTAAxis;
+  }
+  return smemShape;
+}
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
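
To sanity-check the arithmetic in `getScratchConfigForReduce` above, here is
a self-contained restatement with the MLIR types swapped for plain vectors;
the layout numbers in `main` are made up for illustration, not taken from
the PR.

```cpp
#include <cstdio>
#include <vector>

// Standalone re-statement of getScratchConfigForReduce, so the shape
// arithmetic can be checked in isolation.
std::vector<unsigned> scratchShapeForReduce(
    const std::vector<unsigned> &srcShape,
    const std::vector<unsigned> &threadsPerWarp,
    const std::vector<unsigned> &warpsPerCTA, unsigned axis) {
  std::vector<unsigned> smemShape(srcShape);
  bool fastReduce = (axis == 1); // same heuristic as the FIXME above
  if (fastReduce)
    // Intra-warp partials are combined with shfl.sync, so shared memory
    // needs only one slot per warp along the reduced axis.
    smemShape[axis] = warpsPerCTA[axis];
  else
    // Every thread along the reduced axis writes its partial result.
    smemShape[axis] = threadsPerWarp[axis] * warpsPerCTA[axis];
  return smemShape;
}

int main() {
  // e.g. a 128x128 tensor, threadsPerWarp = {4, 8}, warpsPerCTA = {2, 2}
  auto s = scratchShapeForReduce({128, 128}, {4, 8}, {2, 2}, /*axis=*/1);
  std::printf("%u x %u\n", s[0], s[1]); // prints 128 x 2: one slot per warp
  return 0;
}
```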
@@ -127,9 +157,16 @@ private:
       // TODO(Keren): Reduce with index is not supported yet.
       auto value = op->getOperand(0);
       if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-        auto bytes = tensorType.getNumElements() *
-                     tensorType.getElementTypeBitWidth() / 8;
-        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+        if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
+          auto smemShape = getScratchConfigForReduce(reduceOp);
+          unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
+                                           1, std::multiplies{});
+          auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+          allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+        } else {
+          assert(0 && "ReduceOp with input layout other than blocked layout is "
+                      "not implemented yet");
+        }
       }
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();