[BACKEND] Added support for mma layouts in reductions (#863)
Validated hackily by manually modifying the reduction .ttgir in my local cache. There will be a follow-up PR adding some better testing infrastructure to test out conversions and reductions on arbitrary layouts.
@@ -42,15 +42,11 @@ static Type getPointeeType(Type type) {
 namespace gpu {
 
-// TODO: Inheritation of layout attributes
-unsigned getElemsPerThread(Type type) {
-  if (type.isIntOrIndexOrFloat() ||
-      type.isa<triton::Float8Type>() ||
-      type.isa<triton::PointerType>())
-    return 1;
-  auto tensorType = type.cast<RankedTensorType>();
-  auto layout = tensorType.getEncoding();
-  auto shape = tensorType.getShape();
+// TODO: Inheritance of layout attributes
+// so that all distributed layouts implement
+// these utilities
+
+unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
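The TODO above sketches the intended endgame. As a loose, standalone illustration (plain C++ stand-ins, not the MLIR attribute API) of what "all distributed layouts implement these utilities" could look like:

#include <cstdint>
#include <vector>

// Stand-in for a distributed layout attribute: each concrete layout
// reports how many tensor elements a single thread owns.
struct DistributedLayout {
  virtual ~DistributedLayout() = default;
  virtual unsigned elemsPerThread(const std::vector<int64_t> &shape) const = 0;
};

// With inheritance, the free function needs no dyn_cast chain at all.
unsigned getElemsPerThread(const DistributedLayout &layout,
                           const std::vector<int64_t> &shape) {
  return layout.elemsPerThread(shape);
}

// A trivial concrete layout for demonstration: every thread owns one element.
struct ScalarLayout : DistributedLayout {
  unsigned elemsPerThread(const std::vector<int64_t> &) const override {
    return 1;
  }
};

int main() {
  ScalarLayout layout;
  return getElemsPerThread(layout, {32, 64}) == 1 ? 0 : 1;
}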
@@ -67,6 +63,43 @@ unsigned getElemsPerThread(Type type) {
   }
 }
 
+unsigned getElemsPerThread(Type type) {
+  if (type.isIntOrIndexOrFloat() ||
+      type.isa<triton::Float8Type>() ||
+      type.isa<triton::PointerType>())
+    return 1;
+  auto tensorType = type.cast<RankedTensorType>();
+  return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
+}
+
+SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
+                                 blockedLayout.getThreadsPerWarp().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    if (mmaLayout.getVersion() == 1)
+      return SmallVector<unsigned>{4, 8};
+    if (mmaLayout.getVersion() == 2)
+      return SmallVector<unsigned>{8, 4};
+  }
+  assert(0 && "getThreadsPerWarp not implemented");
+  return {};
+}
+
+SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
+                                 blockedLayout.getWarpsPerCTA().end());
+  }
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
+                                 mmaLayout.getWarpsPerCTA().end());
+  }
+  assert(0 && "getWarpsPerCTA not implemented");
+  return {};
+}
+
 SmallVector<unsigned> getSizePerThread(Attribute layout) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
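A quick, self-contained sanity check of the two mma thread grids introduced above: both {4, 8} (version 1) and {8, 4} (version 2) tile exactly one 32-thread warp, just transposed.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Per-warp thread grids from getThreadsPerWarp above.
  std::vector<unsigned> v1 = {4, 8}; // mma version 1
  std::vector<unsigned> v2 = {8, 4}; // mma version 2
  assert(v1[0] * v1[1] == 32 && v2[0] * v2[1] == 32);
  std::printf("both grids cover one 32-thread warp\n");
  return 0;
}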
@@ -129,17 +162,11 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     unsigned dim = sliceLayout.getDim();
     auto parent = sliceLayout.getParent();
-    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-      for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
-        if (d == dim)
-          continue;
-        shape.push_back(blockedParent.getSizePerThread()[d] *
-                        blockedParent.getThreadsPerWarp()[d] *
-                        blockedParent.getWarpsPerCTA()[d]);
-      }
-    } else {
-      assert(0 && "SliceEncodingAttr with parent other than "
-                  "BlockedEncodingAttr not implemented");
+    for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
+      if (d == dim)
+        continue;
+      shape.push_back(getSizePerThread(parent)[d] *
+                      getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.getVersion() == 2)
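The rewritten slice branch now works for any parent layout that implements the three getters, which is what lets mma parents through. A standalone model of that loop follows; the parent figures in main() are illustrative inputs, not values Triton guarantees.

#include <cstdint>
#include <cstdio>
#include <vector>

// Model of the generalized branch: drop the sliced dimension, and elsewhere
// multiply the parent's per-thread, per-warp, and per-CTA extents.
std::vector<int64_t> sliceShapePerCTA(const std::vector<unsigned> &sizePerThread,
                                      const std::vector<unsigned> &threadsPerWarp,
                                      const std::vector<unsigned> &warpsPerCTA,
                                      unsigned dim) {
  std::vector<int64_t> shape;
  for (unsigned d = 0; d < sizePerThread.size(); ++d) {
    if (d == dim)
      continue; // the sliced dimension does not appear in the result
    shape.push_back(int64_t(sizePerThread[d]) * threadsPerWarp[d] *
                    warpsPerCTA[d]);
  }
  return shape;
}

int main() {
  // Example: an mma-v2-like parent ({8, 4} threads per warp) with assumed
  // sizePerThread {2, 2} and warpsPerCTA {2, 2}, sliced along dim 1.
  auto shape = sliceShapePerCTA({2, 2}, {8, 4}, {2, 2}, /*dim=*/1);
  std::printf("%lld\n", (long long)shape[0]); // 2 * 8 * 2 = 32
  return 0;
}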
@@ -289,11 +316,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
   return product<unsigned>(elemsPerThread);
 }
 
-SmallVector<int64_t>
-SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
+template <class T>
+SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
   size_t rank = shape.size();
   unsigned dim = getDim();
-  SmallVector<int64_t> retShape(rank + 1);
+  SmallVector<T> retShape(rank + 1);
   for (unsigned d = 0; d < rank + 1; ++d) {
     if (d < dim)
       retShape[d] = shape[d];
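For reference, a standalone model of the now-templated paddedShape. The branches after `if (d < dim)` fall outside this hunk, so the else logic below is an inferred sketch: re-insert the sliced dimension with extent 1 and shift the remaining extents.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone analogue of SliceEncodingAttr::paddedShape<T>:
// pad the rank-(n-1) slice shape back to rank n so parent-layout utilities
// can consume it.
template <class T>
std::vector<T> paddedShape(const std::vector<T> &shape, unsigned dim) {
  std::vector<T> ret(shape.size() + 1);
  for (unsigned d = 0; d < ret.size(); ++d) {
    if (d < dim)
      ret[d] = shape[d];
    else if (d == dim)
      ret[d] = 1; // inferred: the sliced dimension re-enters with extent 1
    else
      ret[d] = shape[d - 1];
  }
  return ret;
}

int main() {
  auto padded = paddedShape<int64_t>({32, 64}, /*dim=*/1);
  assert((padded == std::vector<int64_t>{32, 1, 64}));
  return 0;
}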
@@ -304,18 +331,15 @@ SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
   }
   return retShape;
 }
+template SmallVector<unsigned>
+SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
+template SmallVector<int64_t>
+SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
 
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  size_t rank = shape.size();
   auto parent = getParent();
-  if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-    assert(rank == blockedParent.getSizePerThread().size() - 1 &&
-           "unexpected rank in SliceEncodingAttr::getElemsPerThread");
-    return blockedParent.getElemsPerThread(paddedShape(shape));
-  } else {
-    assert(0 && "getElemsPerThread not implemented");
-    return 0;
-  }
+  return ::getElemsPerThread(parent, paddedShape(shape));
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
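The two `template SmallVector<...> SliceEncodingAttr::paddedShape<...>` lines added above are explicit instantiations: the template body lives in the .cpp file, so every T used by other translation units must be instantiated there, or the linker has nothing to resolve. A minimal self-contained illustration of the same pattern, with hypothetical names:

// pad.h -- declaration only; the definition is hidden in pad.cpp.
#include <cstdint>
#include <vector>
template <class T>
std::vector<T> padOne(const std::vector<T> &shape, unsigned dim);

// pad.cpp -- definition plus explicit instantiations, mirroring the diff.
template <class T>
std::vector<T> padOne(const std::vector<T> &shape, unsigned dim) {
  std::vector<T> ret(shape.begin(), shape.end());
  ret.insert(ret.begin() + dim, T(1)); // insert extent 1 at `dim`
  return ret;
}
template std::vector<unsigned> padOne<unsigned>(const std::vector<unsigned> &,
                                                unsigned);
template std::vector<int64_t> padOne<int64_t>(const std::vector<int64_t> &,
                                              unsigned);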