[Triton-MLIR][Backend] Fix the definition of MmaEncodingAttr v1, and the output sequence of DotConversion in MMAv1 (#941)
This commit is contained in:
@@ -293,7 +293,7 @@ partitioned between warps.
|
||||
// -------------------------------- version = 1 --------------------------- //
|
||||
|
||||
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
|
||||
Information about this layout can be found in the official PTX documentation
|
||||
Note: the layout is different from the recommended in PTX ISA
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.884 section, FP32 accumulator).
|
||||
|
||||
@@ -301,29 +301,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
|
||||
warp 0
|
||||
--------------------------------/\-------------------------------
|
||||
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
|
||||
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
|
||||
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
|
||||
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
|
||||
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
|
||||
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
|
||||
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
|
||||
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
|
||||
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
|
||||
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
|
||||
[ ..............................................................
|
||||
[ ..............................................................
|
||||
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
|
||||
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
|
||||
warp 1 = warp0 + 32
|
||||
warp 1 = warp0 + 32
|
||||
--------------------------------/\-------------------------------
|
||||
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
|
||||
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
|
||||
[ ..............................................................
|
||||
[ ..............................................................
|
||||
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
|
||||
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
|
||||
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
|
||||
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
|
||||
[ ............................................................... ]
|
||||
|
||||
|
||||
// -------------------------------- version = 2 --------------------------- //
|
||||
|
||||
|
@@ -2876,20 +2876,21 @@ private:
|
||||
} else if (mmaLayout.getVersion() == 1) {
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
|
||||
Value partId = udiv(laneId, _4);
|
||||
Value partIdDiv4 = udiv(partId, _4);
|
||||
Value partIdRem4 = urem(partId, _4);
|
||||
Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
|
||||
partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
|
||||
Value partColOffset = mul(urem(partIdRem4, _2), _8);
|
||||
Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
|
||||
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
|
||||
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
|
||||
Value laneIdDiv16 = udiv(laneId, _16);
|
||||
Value laneIdRem16 = urem(laneId, _16);
|
||||
Value laneIdRem2 = urem(laneId, _2);
|
||||
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
|
||||
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
||||
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
||||
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
||||
mmaRowIdx[0] =
|
||||
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
||||
laneIdRem2);
|
||||
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||
mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset);
|
||||
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
||||
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||
mmaColIdx[2] = add(mmaColIdx[0], _4);
|
||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
|
||||
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
||||
} else {
|
||||
llvm_unreachable("Unexpected MMALayout version");
|
||||
}
|
||||
@@ -3543,6 +3544,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
|
||||
// initialize accumulators
|
||||
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
|
||||
size_t resSize = acc.size();
|
||||
SmallVector<Value> resVals(resSize);
|
||||
|
||||
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
|
||||
auto ha = has[{m, k}];
|
||||
@@ -3586,8 +3589,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto getIntAttr = [&](int v) {
|
||||
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
|
||||
};
|
||||
for (unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i));
|
||||
|
||||
for (unsigned i = 0; i < 8; i++) {
|
||||
Value elem = extract_val(f32_ty, res, getIntAttr(i));
|
||||
acc[idx[i]] = elem;
|
||||
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
|
||||
// verified before MMA
|
||||
resVals[(m * numN / 2 + n) * 8 + i] = elem;
|
||||
}
|
||||
};
|
||||
|
||||
for (unsigned k = 0; k < NK; k += 4)
|
||||
@@ -3596,12 +3605,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
callMMA(m, n, k);
|
||||
}
|
||||
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(acc.size(), type::f32Ty(ctx)));
|
||||
Value res = getStructFromElements(loc, acc, rewriter, structTy);
|
||||
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
|
||||
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user