[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)

This commit is contained in:
Keren Zhou
2022-10-04 09:37:00 -07:00
committed by GitHub
parent f9d7f2f126
commit 289ff293cc
9 changed files with 412 additions and 57 deletions

View File

@@ -72,6 +72,21 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}
/// Returns the number of threads per CTA along each dimension of `layout`.
///
/// For a BlockedEncodingAttr, entry `d` of the result is
/// `threadsPerWarp[d] * warpsPerCTA[d]`, with one entry for every element of
/// the layout's `order` array.
///
/// MmaEncodingAttr and every other encoding are not implemented yet: they
/// trip an assertion. When assertions are compiled out (NDEBUG), an empty
/// vector is returned for those layouts instead.
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
  SmallVector<unsigned> threads;
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    // Fetch the per-dimension arrays once instead of on every iteration.
    auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
    auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
    const size_t rank = blockedLayout.getOrder().size();
    threads.reserve(rank);
    for (size_t d = 0; d < rank; ++d)
      threads.push_back(threadsPerWarp[d] * warpsPerCTA[d]);
  } else if (layout.isa<MmaEncodingAttr>()) {
    // Only the type matters here, so no variable is bound (avoids an
    // unused-variable warning).
    assert(0 && "Unimplemented usage of MmaEncodingAttr");
  } else {
    // The message names this function so a failure points at the right place.
    assert(0 && "Unimplemented usage of getThreadsPerCTA");
  }
  return threads;
}
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {