[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)

The e2e verification of mma v1 is not done yet. 
Get this merged in advance just to prevent more conflicts.
This commit is contained in:
goostavz
2022-11-28 16:10:30 +08:00
committed by GitHub
parent 9d31998a9d
commit 0c1d4d764e
3 changed files with 116 additions and 32 deletions

View File

@@ -78,9 +78,9 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
return SmallVector<unsigned>{4, 8};
return {4, 8};
if (mmaLayout.getVersion() == 2)
return SmallVector<unsigned>{8, 4};
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
@@ -106,9 +106,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return SmallVector<unsigned>{2, 2};
if (mmaLayout.getVersion() == 2) {
return {2, 2};
} else if (mmaLayout.getVersion() == 1) {
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -194,6 +198,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.getVersion() == 1) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
@@ -205,9 +219,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
@@ -358,6 +372,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
}
return res;