[TESTING] Added infrastructure for executing TTGIR programs and a test for layout conversions (#885)

This commit is contained in:
Philippe Tillet
2022-11-18 07:46:45 +01:00
committed by GitHub
parent 9ea6135eb5
commit dab4855bdf
6 changed files with 243 additions and 67 deletions

View File

@@ -1399,37 +1399,35 @@ struct BroadcastOpConversion
Value result = op.result();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto resultTy = result.getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(srcLayout && (srcLayout == resultLayout) &&
"Unexpected layout of BroadcastOp");
auto srcLayout = srcTy.getEncoding();
auto resultLayout = resultTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
assert(rank == resultTy.getRank());
auto order = srcLayout.getOrder();
auto order = triton::gpu::getOrder(srcLayout);
SmallVector<int64_t> srcLogicalShape(2 * rank);
SmallVector<unsigned> srcLogicalOrder(2 * rank);
SmallVector<int64_t> resultLogicalShape(2 * rank);
SmallVector<unsigned> broadcastDims;
for (unsigned d = 0; d < rank; ++d) {
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d];
unsigned resultShapePerCTA = triton::gpu::getSizePerThread(resultLayout)[d] *
triton::gpu::getThreadsPerWarp(resultLayout)[d] *
triton::gpu::getWarpsPerCTA(resultLayout)[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] =
std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
std::max<unsigned>(1, triton::gpu::getSizePerThread(srcLayout)[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
srcLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
}
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
resultLogicalShape[d + rank] = triton::gpu::getSizePerThread(resultLayout)[d];
srcLogicalOrder[d] = order[d] + rank;
srcLogicalOrder[d + rank] = order[d];