[Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)

What is done in this PR:
- [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA`
implementation for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases:
`emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with
ptx instruction `shfl.sync`
- [x] Add support for scalar value in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
This commit is contained in:
Qingyi Liu
2022-10-28 11:07:45 +08:00
committed by GitHub
parent 3e6cc6d66c
commit 42db3538e4
7 changed files with 680 additions and 57 deletions

View File

@@ -63,6 +63,19 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<unsigned> sizePerThread(
blockedParent.getSizePerThread().begin(),
blockedParent.getSizePerThread().end());
sizePerThread.erase(sizePerThread.begin() + dim);
return sizePerThread;
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
@@ -95,6 +108,21 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
shape.push_back(blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(blockedParent.getSizePerThread()[d] *
blockedParent.getThreadsPerWarp()[d] *
blockedParent.getWarpsPerCTA()[d]);
}
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
@@ -206,6 +234,22 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return product<unsigned>(elemsPerThread);
}
SmallVector<int64_t>
SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
unsigned dim = getDim();
SmallVector<int64_t> retShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
retShape[d] = shape[d];
else if (d == dim)
retShape[d] = 1;
else
retShape[d] = shape[d - 1];
}
return retShape;
}
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
@@ -213,16 +257,7 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
assert(rank == blockedParent.getSizePerThread().size() - 1 &&
"unexpected rank in SliceEncodingAttr::getElemsPerThread");
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
paddedShape[d] = shape[d];
else if (d == dim)
paddedShape[d] = 1;
else
paddedShape[d] = shape[d - 1];
}
return blockedParent.getElemsPerThread(paddedShape);
return blockedParent.getElemsPerThread(paddedShape(shape));
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;