[Triton-MLIR] Two fixes on allocation and backend related with MMA v1 (#930)

This commit is contained in:
goostavz
2022-11-30 17:27:26 +08:00
committed by GitHub
parent 9bb54402b3
commit 4e6a8209ed
5 changed files with 47 additions and 23 deletions

View File

@@ -2901,12 +2901,12 @@ private:
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value colWarpOffset = mul(multiDimWarpId[0], _16);
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], _8);
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.getVersion() == 1) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
@@ -2920,7 +2920,7 @@ private:
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _4);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
@@ -2931,28 +2931,28 @@ private:
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.getVersion() == 2) {
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.getVersion() == 1) {
// the order of elements in a thread:
// c0, c1, c4, c5
// c2, c3, c6, c7
// c0, c1, ... c4, c5
// c2, c3, ... c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[0];
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[1];
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[0];
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[1];
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
@@ -3051,6 +3051,7 @@ void ConvertLayoutOpConversion::processReplica(
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3171,6 +3172,11 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
// TODO[Keren]: A temporary workaround for an issue from membar pass.
// https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
barrier();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();