[Triton-MLIR][Backend] Fix convert_layout blocked->shared in non-default order (#876)

This PR fix the problem of TN/NT GEMM correctness when no SCF involved.
I'll continue to clean up getLinearIndex/getMultiDimIndex in a uniformed
way which should be benifical to avoid different kinds of order issues.
This is not fully done yet, just merge to sync the code.
This commit is contained in:
goostavz
2022-11-15 09:02:46 +08:00
committed by GitHub
parent 1eedaf7bec
commit c28cfd821b
2 changed files with 155 additions and 58 deletions

View File

@@ -347,7 +347,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
return llvmStruct;
}
// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
// TODO[goostavz]: to be deprecated
// delinearize supposing order is [n, .. , 2, 1, 0]
template <typename T>
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
@@ -365,7 +366,40 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
return multiDimIndex;
}
// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
// delinearize supposing order is [0, 1, .. , n]
template <typename T>
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearRemain = linearIndex;
SmallVector<T> multiDimIndex(rank);
for (int i = rank - 1; i >= 0; --i) {
multiDimIndex[i] = linearRemain / accMul;
linearRemain = linearRemain % accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return multiDimIndex;
}
template <typename T>
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
ArrayRef<unsigned> order) {
size_t rank = shape.size();
assert(rank == order.size());
auto reordered = reorder(shape, order);
auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
SmallVector<T> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
// TODO[goostavz]: to be deprecated
// linearize supposing order is [n, .. , 2, 1, 0]
template <typename T>
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
@@ -382,6 +416,30 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
return linearIndex;
}
template <typename T>
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearIndex = 0;
for (int i = rank - 1; i >= 0; --i) {
linearIndex += multiDimIndex[i] * accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return linearIndex;
}
template <typename T>
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
reorder(shape, order));
}
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
@@ -632,6 +690,7 @@ public:
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
@@ -663,9 +722,9 @@ public:
unsigned linearNanoTileId = n / totalSizePerThread;
unsigned linearNanoTileElemId = n % totalSizePerThread;
SmallVector<unsigned> multiDimNanoTileId =
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
SmallVector<unsigned> multiDimNanoTileElemId =
getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
linearNanoTileElemId, sizePerThread, order);
for (unsigned k = 0; k < rank; ++k) {
unsigned reorderedMultiDimId =
multiDimNanoTileId[k] *
@@ -1881,8 +1940,6 @@ struct PrintfOpConversion
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
std::string getFormatSubstr(Value value) const {
Type type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
@@ -1924,13 +1981,11 @@ struct PrintfOpConversion
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
auto *context = rewriter.getContext();
auto type = value.getType();
type.dump();
unsigned width = type.getIntOrFloatBitWidth();
Value newOp = value;
Type newType = type;
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && width < 32) {
if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
@@ -3057,23 +3112,24 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
}
unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
wordVecs[wordVecIdx] =
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
// end of replication, store the vectors into shared memory
unsigned linearRepIdx = i / srcAccumSizeInThreads;
auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
auto multiDimRepIdx =
getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
++linearWordIdx) {
// step 1: recover the multidim_index from the index of input_elements
auto multiDimWordIdx =
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
SmallVector<Value> multiDimIdx(2);
auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
@@ -3083,12 +3139,12 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
// step 2: do swizzling
Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
phaseId = urem(phaseId, idx_val(maxPhase));
Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));