[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)
This commit is contained in:
@@ -72,6 +72,21 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
|
||||
SmallVector<unsigned> threads;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
|
||||
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
|
||||
blockedLayout.getWarpsPerCTA()[d]);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(0 && "Unimplemented usage of MmaEncodingAttr");
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
}
|
||||
|
||||
return threads;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
SmallVector<unsigned> shape;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
|
Reference in New Issue
Block a user