[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)
This commit is contained in:
@@ -58,17 +58,56 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the per-thread sizes associated with `layout`.
///
/// - Blocked encoding: a copy of the encoding's own `sizePerThread` array.
/// - MMA encoding: the fixed pair {2, 2}. Only version 2 is handled; any
///   other version trips the assertion.
/// - Any other attribute: asserts; when assertions are compiled out an empty
///   vector is returned instead.
SmallVector<unsigned> getSizePerThread(Attribute layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    auto sizes = blocked.getSizePerThread();
    return SmallVector<unsigned>(sizes.begin(), sizes.end());
  }

  if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    assert(mma.getVersion() == 2 &&
           "mmaLayout version = 1 is not implemented yet");
    return SmallVector<unsigned>{2, 2};
  }

  assert(0 && "getSizePerThread not implemented");
  return {};
}
|
||||
|
||||
unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getSizePerThread()[d] *
|
||||
blockedLayout.getThreadsPerWarp()[d] *
|
||||
blockedLayout.getWarpsPerCTA()[d];
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(mmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
assert(d < 2 && "Unexpected usage of getShapePerCTA");
|
||||
if (d == 0) {
|
||||
return 16 * mmaLayout.getWarpsPerCTA()[0];
|
||||
} else {
|
||||
// d == 1
|
||||
return 8 * mmaLayout.getWarpsPerCTA()[1];
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the dimension order associated with `layout`.
///
/// - Blocked encoding: a copy of the encoding's own `order` array.
/// - MMA encoding: the fixed order {1, 0}.
/// - Shared encoding: a copy of the encoding's own `order` array.
/// - Any other attribute: asserts; when assertions are compiled out an empty
///   vector is returned instead.
SmallVector<unsigned> getOrder(const Attribute &layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    auto order = blocked.getOrder();
    return SmallVector<unsigned>(order.begin(), order.end());
  }

  // Only the attribute kind matters here, so test with isa<> instead of
  // binding a dyn_cast result that is never read.
  if (layout.isa<MmaEncodingAttr>())
    return SmallVector<unsigned>{1, 0};

  if (auto shared = layout.dyn_cast<SharedEncodingAttr>()) {
    auto order = shared.getOrder();
    return SmallVector<unsigned>(order.begin(), order.end());
  }

  assert(0 && "Unimplemented usage of getOrder");
  return {};
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
@@ -177,9 +216,11 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
}
|
||||
|
||||
/// Returns the per-thread element count for a tensor of the given `shape`
/// under this MMA encoding.
///
/// `shape` must have rank 2. Each dimension contributes
/// `ceil(shape[i], tile_i * warpsPerCTA[i]) * 2`, with tile_0 = 16 and
/// tile_1 = 8; the result is the product of the two contributions.
/// NOTE(review): `ceil` is defined outside this view — presumably ceiling
/// division; confirm.
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // The earlier `product(shape) / product(warpsPerCTA)` computation and its
  // early return are dropped: they made the computation below unreachable.
  assert(shape.size() == 2 && "Unexpected rank of mma layout");
  unsigned elemsDim0 = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
  unsigned elemsDim1 = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
  return elemsDim0 * elemsDim1;
}
|
||||
|
||||
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
|
Reference in New Issue
Block a user