[Triton-MLIR][BACKEND] Refine dot conversion (#710)
This PR does 1. Refine the dot conversion 2. some other tiny code refinement
This commit is contained in:
@@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] = std::max(
|
||||
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
|
||||
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
|
||||
paddedRepShape[d] =
|
||||
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||
}
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
|
Reference in New Issue
Block a user