[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)

Fix mma.16816 s8 precision error

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
This commit is contained in:
Yan Chunwei
2022-11-09 12:23:43 +08:00
committed by GitHub
parent e517b58d59
commit de5b84c476
4 changed files with 146 additions and 42 deletions

View File

@@ -117,10 +117,13 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
"BlockedEncodingAttr not implemented");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.getVersion() == 2)
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.getVersion() == 1)
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
assert(0 && "Unexpected MMA layout version found");
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}