[Triton-MLIR][BACKEND] Refine dot conversion (#710)

This PR does

1. Refine the dot conversion
2. some other tiny code refinement
This commit is contained in:
Yan Chunwei
2022-09-27 14:38:34 +08:00
committed by GitHub
parent 61b61755e5
commit 3a84278530
11 changed files with 439 additions and 291 deletions

View File

@@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
paddedRepShape[d] =
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
}
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {