[TESTING] Added infrastructure for executing TTGIR programs and a test for layout conversions (#885)
@@ -1399,37 +1399,35 @@ struct BroadcastOpConversion
     Value result = op.result();
     auto srcTy = op.src().getType().cast<RankedTensorType>();
     auto resultTy = result.getType().cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    assert(srcLayout && (srcLayout == resultLayout) &&
-           "Unexpected layout of BroadcastOp");
+    auto srcLayout = srcTy.getEncoding();
+    auto resultLayout = resultTy.getEncoding();
     auto srcShape = srcTy.getShape();
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
-    auto order = srcLayout.getOrder();
+    auto order = triton::gpu::getOrder(srcLayout);
 
     SmallVector<int64_t> srcLogicalShape(2 * rank);
     SmallVector<unsigned> srcLogicalOrder(2 * rank);
     SmallVector<int64_t> resultLogicalShape(2 * rank);
     SmallVector<unsigned> broadcastDims;
     for (unsigned d = 0; d < rank; ++d) {
-      unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
-                                   resultLayout.getThreadsPerWarp()[d] *
-                                   resultLayout.getWarpsPerCTA()[d];
+      unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
+                                   triton::gpu::getThreadsPerWarp(resultLayout)[d] *
+                                   triton::gpu::getWarpsPerCTA(resultLayout)[d];
       int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
       if (srcShape[d] != resultShape[d]) {
         assert(srcShape[d] == 1);
         broadcastDims.push_back(d);
         srcLogicalShape[d] = 1;
         srcLogicalShape[d + rank] =
-            std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
+            std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
       } else {
         srcLogicalShape[d] = numCtas;
-        srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+        srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
       }
       resultLogicalShape[d] = numCtas;
-      resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+      resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
 
       srcLogicalOrder[d] = order[d] + rank;
       srcLogicalOrder[d + rank] = order[d];
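
What the hunk changes: BroadcastOpConversion stops requiring the source and result encodings to be BlockedEncodingAttr (and drops the assert tying the two together). Layout properties are instead queried through the generic triton::gpu free functions (getOrder, getSizePerThread, getThreadsPerWarp, getWarpsPerCTA), which dispatch on the concrete encoding attribute, so the same lowering covers any layout those helpers understand, which is what the commit's layout-conversion test exercises. A minimal sketch of that dispatch shape, using the hypothetical name getSizePerThreadSketch (the real triton::gpu::getSizePerThread lives in the TritonGPU dialect and handles more encodings than shown here):

// Illustrative sketch only; not the actual TritonGPU implementation.
SmallVector<unsigned> getSizePerThreadSketch(Attribute layout) {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    // Blocked layouts store sizePerThread directly on the attribute.
    auto sizePerThread = blocked.getSizePerThread();
    return SmallVector<unsigned>(sizePerThread.begin(), sizePerThread.end());
  }
  // Other encodings (slice, mma, ...) would get their own cases here.
  assert(0 && "encoding not covered in this sketch");
  return {};
}

Returning a SmallVector by value rather than an ArrayRef into the attribute is a plausible reason these helpers look the way they do: some encodings have to compute the vector rather than forward stored data. That rationale is not stated in this diff.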
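To make the per-CTA arithmetic in the loop concrete, with illustrative numbers not taken from the source: if along dimension d the result layout has sizePerThread = 2, threadsPerWarp = 16, and warpsPerCTA = 4, then resultShapePerCTA = 2 * 16 * 4 = 128, and a result extent of resultShape[d] = 256 gives numCtas = ceil(256 / 128) = 2. If d is a broadcast dimension (srcShape[d] == 1), the source contributes logical extents 1 and max(1, the source layout's sizePerThread[d]); otherwise it mirrors the result's numCtas = 2 and sizePerThread = 2. The 2 * rank logical shapes thus split each tensor dimension into a per-CTA-repetition part and a per-thread part, and srcLogicalOrder reorders exactly those two halves.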