[Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1 (#972)
This commit is contained in:
@@ -2886,11 +2886,15 @@ private:
|
||||
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
||||
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
||||
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], _16);
|
||||
mmaRowIdx[0] =
|
||||
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
||||
laneIdRem2);
|
||||
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
|
||||
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
||||
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
|
||||
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
||||
|
Reference in New Issue
Block a user