[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)
@@ -39,6 +39,37 @@ static Type getPointeeType(Type type) {
  return Type();
}

namespace gpu {

// TODO: Inheritance of layout attributes
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  size_t rank = shape.size();
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getElemsPerThread(shape);
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    return sliceLayout.getElemsPerThread(shape);
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    return mmaLayout.getElemsPerThread(shape);
  } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
    return sharedLayout.getElemsPerThread(shape);
  } else {
    assert(0 && "getElemsPerThread not implemented");
    return 0;
  }
}
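
// Usage sketch (illustrative; `ty` stands for a hypothetical
// mlir::RankedTensorType whose encoding is one of the layouts above):
//   unsigned elems = getElemsPerThread(ty.getEncoding(), ty.getShape());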

unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getSizePerThread()[d] *
           blockedLayout.getThreadsPerWarp()[d] *
           blockedLayout.getWarpsPerCTA()[d];
  } else {
    assert(0 && "Unimplemented usage of getShapePerCTA");
    return 0;
  }
}
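
// Worked example with hypothetical parameters: for a blocked layout with
// sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2],
// getShapePerCTA(layout, 1) returns 2 * 4 * 2 = 16: the extent covered by
// one CTA along dimension 1 before the pattern repeats.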

} // namespace gpu
} // namespace triton
} // namespace mlir

@@ -108,6 +139,55 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
  return SliceEncodingAttr::get(getContext(), axis, *this);
}

unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  size_t rank = shape.size();
  assert(rank == getSizePerThread().size() &&
         "unexpected rank in BlockedEncodingAttr::getElemsPerThread");
  SmallVector<unsigned> elemsPerThreadPerDim(rank);
  for (size_t i = 0; i < rank; ++i) {
    unsigned t =
        getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i];
    elemsPerThreadPerDim[i] =
        ceil<unsigned>(shape[i], t) * getSizePerThread()[i];
  }
  return product<unsigned>(elemsPerThreadPerDim);
}
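
// Worked example with hypothetical parameters: for shape = [32, 64],
// sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2]:
//   dim 0: t = 2 * 8 * 1 = 16, ceil(32 / 16) * 2 = 4
//   dim 1: t = 2 * 4 * 2 = 16, ceil(64 / 16) * 2 = 8
// so each thread holds 4 * 8 = 32 elements.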

unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  size_t rank = shape.size();
  auto parent = getParent();
  unsigned dim = getDim();
  if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
    assert(rank == blockedParent.getSizePerThread().size() - 1 &&
           "unexpected rank in SliceEncodingAttr::getElemsPerThread");
    SmallVector<int64_t> paddedShape(rank + 1);
    for (unsigned d = 0; d < rank + 1; ++d) {
      if (d < dim)
        paddedShape[d] = shape[d];
      else if (d == dim)
        paddedShape[d] = 1;
      else
        paddedShape[d] = shape[d - 1];
    }
    return blockedParent.getElemsPerThread(paddedShape);
  } else {
    assert(0 && "getElemsPerThread not implemented");
    return 0;
  }
}
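
// A slice layout reuses its parent's thread mapping: the sliced dimension
// `dim` is re-inserted with extent 1 so the parent's formula applies
// unchanged. Example with hypothetical parameters: a slice with dim = 0
// over a rank-2 blocked parent, queried with shape = [64], pads the shape
// to [1, 64] and delegates to BlockedEncodingAttr::getElemsPerThread.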

unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // TODO:
  assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented");
  return 0;
}

unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // TODO:
  assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
  return 0;
}

//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//