[BACKEND] Added support for mma layouts in reductions (#863)

Validated hackily by manually modifying the reduction .ttgir in my local
cache. A follow-up PR will add better testing infrastructure for
exercising conversions and reductions on arbitrary layouts.
Author: Philippe Tillet
Date: 2022-11-10 09:58:07 -08:00
Committed by: GitHub
Parent: 57fd1864a7
Commit: 2aa538ec2e

6 changed files with 469 additions and 365 deletions


@@ -42,15 +42,11 @@ static Type getPointeeType(Type type) {
 namespace gpu {
-// TODO: Inheritation of layout attributes
-unsigned getElemsPerThread(Type type) {
-  if (type.isIntOrIndexOrFloat() ||
-      type.isa<triton::Float8Type>() ||
-      type.isa<triton::PointerType>())
-    return 1;
-  auto tensorType = type.cast<RankedTensorType>();
-  auto layout = tensorType.getEncoding();
-  auto shape = tensorType.getShape();
+// TODO: Inheritance of layout attributes
+// so that all distributed layouts implement
+// these utilities
+unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
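The point of dispatching on a (layout attribute, shape) pair rather than on a Type shows up in the later hunks: a slice layout (the encoding of a reduction's result) answers queries by re-asking its parent layout with a padded shape, which a Type-based entry point cannot express. Below is a toy standalone sketch of that recursion; the types are hypothetical stand-ins, not the real MLIR attributes.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the real MLIR attributes; illustrative only.
struct Layout {
  virtual ~Layout() = default;
  // Analogue of getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape).
  virtual unsigned elemsPerThread(const std::vector<int64_t> &shape) const = 0;
};

struct Blocked final : Layout {
  explicit Blocked(unsigned n) : perThread(n) {}
  unsigned perThread; // toy: elements each thread owns
  unsigned elemsPerThread(const std::vector<int64_t> &) const override {
    return perThread;
  }
};

struct Slice final : Layout {
  Slice(unsigned d, const Layout *p) : dim(d), parent(p) {}
  unsigned dim;         // the dimension the reduction removed
  const Layout *parent; // blocked, mma, ... anything
  unsigned elemsPerThread(const std::vector<int64_t> &shape) const override {
    std::vector<int64_t> padded(shape);
    padded.insert(padded.begin() + dim, 1); // analogue of paddedShape
    return parent->elemsPerThread(padded);  // recurse into any parent
  }
};

int main() {
  Blocked blocked(4);
  Slice slice(/*dim=*/1, &blocked);
  assert(slice.elemsPerThread({128}) == 4); // query forwards with {128, 1}
}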
@@ -67,6 +63,43 @@ unsigned getElemsPerThread(Type type) {
   }
 }
+unsigned getElemsPerThread(Type type) {
+  if (type.isIntOrIndexOrFloat() ||
+      type.isa<triton::Float8Type>() ||
+      type.isa<triton::PointerType>())
+    return 1;
+  auto tensorType = type.cast<RankedTensorType>();
+  return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
+}
+
+SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
+                                 blockedLayout.getThreadsPerWarp().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    if (mmaLayout.getVersion() == 1)
+      return SmallVector<unsigned>{4, 8};
+    if (mmaLayout.getVersion() == 2)
+      return SmallVector<unsigned>{8, 4};
+  }
+  assert(0 && "getThreadsPerWarp not implemented");
+  return {};
+}
+
+SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
+                                 blockedLayout.getWarpsPerCTA().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
+                                 mmaLayout.getWarpsPerCTA().end());
+  }
+  assert(0 && "getWarpsPerCTA not implemented");
+  return {};
+}
+
 SmallVector<unsigned> getSizePerThread(Attribute layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
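A quick standalone sanity check (not Triton code) of the lane grids hard-coded above: mma v1 arranges a warp's lanes 4x8, mma v2 arranges them 8x4, and either grid must tile exactly the 32 lanes of a warp. The helper name below is made up for illustration.

#include <array>
#include <cassert>
#include <cstdio>

// Mirrors the {4, 8} / {8, 4} grids returned by getThreadsPerWarp above.
std::array<unsigned, 2> mmaThreadsPerWarp(unsigned version) {
  return version == 1 ? std::array<unsigned, 2>{4, 8}
                      : std::array<unsigned, 2>{8, 4};
}

int main() {
  for (unsigned v : {1u, 2u}) {
    auto grid = mmaThreadsPerWarp(v);
    assert(grid[0] * grid[1] == 32 && "a warp has 32 lanes");
    std::printf("mma v%u: %u x %u lanes per warp\n", v, grid[0], grid[1]);
  }
}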
@@ -129,17 +162,11 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
-        if (d == dim)
-          continue;
-        shape.push_back(blockedParent.getSizePerThread()[d] *
-                        blockedParent.getThreadsPerWarp()[d] *
-                        blockedParent.getWarpsPerCTA()[d]);
-      }
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
+    for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
+      if (d == dim)
+        continue;
+      shape.push_back(getSizePerThread(parent)[d] *
+                      getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.getVersion() == 2)
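The generic slice branch above computes, for each surviving dimension, sizePerThread * threadsPerWarp * warpsPerCTA, skipping the sliced dimension. A standalone sketch of that arithmetic; the layout numbers are made up for illustration, not taken from any real encoding.

#include <cassert>
#include <vector>

// Toy analogue of the new slice branch in getShapePerCTA.
std::vector<unsigned>
sliceShapePerCTA(const std::vector<unsigned> &sizePerThread,
                 const std::vector<unsigned> &threadsPerWarp,
                 const std::vector<unsigned> &warpsPerCTA, unsigned dim) {
  std::vector<unsigned> shape;
  for (unsigned d = 0; d < sizePerThread.size(); ++d) {
    if (d == dim)
      continue; // the sliced dimension contributes nothing
    shape.push_back(sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
  }
  return shape;
}

int main() {
  // Hypothetical blocked parent: each thread holds 2x2 elements, a warp is
  // 8x4 threads, there are 2x2 warps; slicing away dim 0 leaves a single
  // dimension covering 2 * 4 * 2 = 16 elements per CTA.
  auto s = sliceShapePerCTA({2, 2}, {8, 4}, {2, 2}, /*dim=*/0);
  assert((s == std::vector<unsigned>{16}));
}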
@@ -289,11 +316,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   return product<unsigned>(elemsPerThread);
 }
-SmallVector<int64_t>
-SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
+template <class T>
+SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
   size_t rank = shape.size();
   unsigned dim = getDim();
-  SmallVector<int64_t> retShape(rank + 1);
+  SmallVector<T> retShape(rank + 1);
   for (unsigned d = 0; d < rank + 1; ++d) {
     if (d < dim)
       retShape[d] = shape[d];
@@ -304,18 +331,15 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
   }
   return retShape;
 }
+template SmallVector<unsigned>
+SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
+template SmallVector<int64_t>
+SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   size_t rank = shape.size();
   auto parent = getParent();
-  if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-    assert(rank == blockedParent.getSizePerThread().size() - 1 &&
-           "unexpected rank in SliceEncodingAttr::getElemsPerThread");
-    return blockedParent.getElemsPerThread(paddedShape(shape));
-  } else {
-    assert(0 && "getElemsPerThread not implemented");
-    return 0;
-  }
+  return ::getElemsPerThread(parent, paddedShape(shape));
 }
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
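One C++ detail in the last two hunks: paddedShape became a template but its definition stays in the .cpp file, so the two explicit instantiations (for unsigned and int64_t) are what keep callers in other translation units linking. A minimal sketch of the same pattern, with hypothetical names:

#include <cstdint>
#include <vector>

// pad.h (what callers see): declaration only.
template <class T> std::vector<T> pad(const std::vector<T> &shape);

// pad.cpp: the out-of-line definition. Because it is invisible to other
// translation units, every T used elsewhere must be explicitly
// instantiated here, mirroring the paddedShape<unsigned> /
// paddedShape<int64_t> instantiations in the diff above.
template <class T> std::vector<T> pad(const std::vector<T> &shape) {
  std::vector<T> ret(shape);
  ret.push_back(T(1));
  return ret;
}
template std::vector<unsigned> pad<unsigned>(const std::vector<unsigned> &);
template std::vector<int64_t> pad<int64_t>(const std::vector<int64_t> &);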