[BACKEND] Added support for mma layouts in reductions (#863)

Validated hackily by manually modifying the reduction .ttgir in my local
cache. There will be a follow-up PR adding some better testing
infrastructure to test out conversions and reductions on arbitrary
layouts.
Author: Philippe Tillet
Date: 2022-11-10 09:58:07 -08:00
Committed by: GitHub
Parent: 57fd1864a7
Commit: 2aa538ec2e
6 changed files with 469 additions and 365 deletions

@@ -87,22 +87,22 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto axis = op.axis();
-  bool fastReduce = axis == srcLayout.getOrder()[0];
+  bool fastReduce = axis == getOrder(srcLayout)[0];
   SmallVector<unsigned> smemShape;
   for (auto d : srcShape)
     smemShape.push_back(d);
   if (fastReduce) {
-    unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
+    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
     smemShape[axis] = sizeInterWarps;
   } else {
-    unsigned threadsPerCTAAxis =
-        srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
+    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
+                                 gpu::getWarpsPerCTA(srcLayout)[axis];
     smemShape[axis] = threadsPerCTAAxis;
   }
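
The hunk above swaps BlockedEncodingAttr-specific accessors for layout-generic helpers (getOrder, getWarpsPerCTA, getThreadsPerWarp), so the same scratch-size logic also covers mma layouts. Below is a minimal sketch of the kind of dispatch such a helper performs; MmaEncodingAttr and the returned order are assumptions for illustration, not taken from this diff.

// Illustrative sketch only, not the upstream implementation: layout-generic
// helpers dispatch on the concrete encoding attribute, which is why the
// caller above no longer casts the encoding to BlockedEncodingAttr.
// (Assumes the Triton GPU dialect headers, e.g.
//  triton/Dialect/TritonGPU/IR/Dialect.h, are already included.)
SmallVector<unsigned> getOrderSketch(Attribute layout) {
  if (auto blocked = layout.dyn_cast<triton::gpu::BlockedEncodingAttr>())
    return SmallVector<unsigned>(blocked.getOrder().begin(),
                                 blocked.getOrder().end());
  if (layout.isa<triton::gpu::MmaEncodingAttr>())
    return {1, 0}; // assumed: mma tiles treated as row-major for ordering
  llvm_unreachable("unsupported layout in getOrderSketch");
}
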
@@ -161,16 +161,11 @@ private:
       // TODO(Keren): Reduce with index is not supported yet.
       auto value = op->getOperand(0);
       if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-        if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
-          auto smemShape = getScratchConfigForReduce(reduceOp);
-          unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
-                                           1, std::multiplies{});
-          auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-          allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
-        } else {
-          assert(0 && "ReduceOp with input layout other than blocked layout is "
-                      "not implemented yet");
-        }
+        auto smemShape = getScratchConfigForReduce(reduceOp);
+        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                         std::multiplies{});
+        auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
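
With the blocked-only assert removed, the scratch buffer for a reduction is sized the same way for every layout: the per-axis shape comes from getScratchConfigForReduce and the byte count is elems * element bit width / 8. The following self-contained sketch works through that arithmetic with assumed example values (256x128 f32 tensor, reduction along the fast axis, warpsPerCTA[axis] = 4); the numbers are illustrative, not from the diff.

// Standalone arithmetic sketch; all values below are assumed for illustration.
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<unsigned> smemShape = {256, 128}; // source tensor shape
  unsigned axis = 1;                            // reduction axis (fast axis)
  unsigned warpsPerCTAAxis = 4;                 // gpu::getWarpsPerCTA(layout)[axis]
  // Fast-reduce path: one partial result per warp along the reduced axis.
  smemShape[axis] = warpsPerCTAAxis;            // smemShape becomes {256, 4}
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});
  unsigned bitWidth = 32;                       // f32 element type
  unsigned bytes = elems * bitWidth / 8;        // 256 * 4 * 32 / 8 = 4096
  std::printf("scratch bytes = %u\n", bytes);   // prints 4096
  return 0;
}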