[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)
The e2e verification of mma v1 is not done yet. Get this merged in advance just to prevent more conflicts.
This commit is contained in:
@@ -78,9 +78,9 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
}
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.getVersion() == 1)
|
||||
return SmallVector<unsigned>{4, 8};
|
||||
return {4, 8};
|
||||
if (mmaLayout.getVersion() == 2)
|
||||
return SmallVector<unsigned>{8, 4};
|
||||
return {8, 4};
|
||||
}
|
||||
assert(0 && "getThreadsPerWarp not implemented");
|
||||
return {};
|
||||
@@ -106,9 +106,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
return getSizePerThread(sliceLayout.getParent());
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(mmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
return SmallVector<unsigned>{2, 2};
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
return {2, 2};
|
||||
} else if (mmaLayout.getVersion() == 1) {
|
||||
return {2, 4};
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -194,6 +198,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
}
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else if (mmaLayout.getVersion() == 1) {
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
}
|
||||
@@ -205,9 +219,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
|
||||
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
|
||||
blockedLayout.getOrder().end());
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return SmallVector<unsigned>{1, 0};
|
||||
return {1, 0};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return SmallVector<unsigned>{1, 0};
|
||||
return {1, 0};
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
@@ -358,6 +372,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
res = elemsCol * elemsRow;
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
|
||||
return res;
|
||||
|
Reference in New Issue
Block a user