[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)

The e2e verification of mma v1 is not done yet. 
Get this merged in advance just to prevent more conflicts.
This commit is contained in:
goostavz
2022-11-28 16:10:30 +08:00
committed by GitHub
parent 9d31998a9d
commit 0c1d4d764e
3 changed files with 116 additions and 32 deletions

View File

@@ -2926,7 +2926,7 @@ private:
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
@@ -2936,31 +2936,79 @@ private:
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
Value four = idx_val(4);
Value mmaGrpId = udiv(laneId, four);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp = urem(laneId, four);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
Value _1 = idx_val(1);
Value _2 = idx_val(2);
Value _4 = idx_val(4);
Value _8 = idx_val(8);
Value _16 = idx_val(16);
if (mmaLayout.getVersion() == 2) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value colWarpOffset = mul(multiDimWarpId[0], _16);
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], _8);
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
} else if (mmaLayout.getVersion() == 1) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value partId = udiv(laneId, _4);
Value partIdDiv4 = udiv(partId, _4);
Value partIdRem4 = urem(partId, _4);
Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
Value partColOffset = mul(urem(partIdRem4, _2), _8);
Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _4);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
SmallVector<Value> multiDimOffset(rank);
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(multiDimOffset[0],
idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(multiDimOffset[1],
idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
if (mmaLayout.getVersion() == 2) {
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.getVersion() == 1) {
// the order of elements in a thread:
// c0, c1, c4, c5
// c2, c3, c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[0];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[1];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[0];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[1];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
return multiDimOffset;
}
llvm_unreachable("unexpected layout in getMultiDimOffset");