[Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)
Fix mma.16816 s8 precision error Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
This commit is contained in:
@@ -117,10 +117,13 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
"BlockedEncodingAttr not implemented");
|
||||
}
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(mmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
if (mmaLayout.getVersion() == 2)
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
if (mmaLayout.getVersion() == 1)
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
}
|
||||
|
Reference in New Issue
Block a user