[Triton-MLIR][Backend] Some cleanup in getMultiDimIndex/getLinearIndex (#880)
This commit is contained in:
@@ -356,25 +356,6 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
|||||||
return llvmStruct;
|
return llvmStruct;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO[goostavz]: to be deprecated
|
|
||||||
// delinearize supposing order is [n, .. , 2, 1, 0]
|
|
||||||
template <typename T>
|
|
||||||
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
|
|
||||||
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
|
|
||||||
size_t rank = shape.size();
|
|
||||||
T accMul = product(shape.drop_front());
|
|
||||||
T linearRemain = linearIndex;
|
|
||||||
SmallVector<T> multiDimIndex(rank);
|
|
||||||
for (size_t i = 0; i < rank; ++i) {
|
|
||||||
multiDimIndex[i] = linearRemain / accMul;
|
|
||||||
linearRemain = linearRemain % accMul;
|
|
||||||
if (i != (rank - 1)) {
|
|
||||||
accMul = accMul / shape[i + 1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return multiDimIndex;
|
|
||||||
}
|
|
||||||
|
|
||||||
// delinearize supposing order is [0, 1, .. , n]
|
// delinearize supposing order is [0, 1, .. , n]
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
|
||||||
@@ -407,24 +388,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
|
|||||||
return multiDim;
|
return multiDim;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO[goostavz]: to be deprecated
|
// linearize supposing order is [0, 1, .. , n]
|
||||||
// linearize supposing order is [n, .. , 2, 1, 0]
|
|
||||||
template <typename T>
|
|
||||||
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
|
||||||
assert(multiDimIndex.size() == shape.size());
|
|
||||||
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
|
|
||||||
size_t rank = shape.size();
|
|
||||||
T accMul = product(shape.drop_front());
|
|
||||||
T linearIndex = 0;
|
|
||||||
for (size_t i = 0; i < rank; ++i) {
|
|
||||||
linearIndex += multiDimIndex[i] * accMul;
|
|
||||||
if (i != (rank - 1)) {
|
|
||||||
accMul = accMul / shape[i + 1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return linearIndex;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||||
assert(multiDimIndex.size() == shape.size());
|
assert(multiDimIndex.size() == shape.size());
|
||||||
@@ -621,6 +585,13 @@ public:
|
|||||||
return multiDim;
|
return multiDim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||||
|
ArrayRef<unsigned> order) const {
|
||||||
|
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
|
||||||
|
reorder<unsigned>(shape, order));
|
||||||
|
}
|
||||||
|
|
||||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
|
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
|
||||||
int rank = multiDim.size();
|
int rank = multiDim.size();
|
||||||
@@ -1436,10 +1407,12 @@ struct BroadcastOpConversion
|
|||||||
auto resultShape = resultTy.getShape();
|
auto resultShape = resultTy.getShape();
|
||||||
unsigned rank = srcTy.getRank();
|
unsigned rank = srcTy.getRank();
|
||||||
assert(rank == resultTy.getRank());
|
assert(rank == resultTy.getRank());
|
||||||
|
auto order = srcLayout.getOrder();
|
||||||
|
|
||||||
SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
|
SmallVector<int64_t> srcLogicalShape(2 * rank);
|
||||||
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
|
SmallVector<unsigned> srcLogicalOrder(2 * rank);
|
||||||
SmallVector<unsigned, 2> broadcastDims;
|
SmallVector<int64_t> resultLogicalShape(2 * rank);
|
||||||
|
SmallVector<unsigned> broadcastDims;
|
||||||
for (unsigned d = 0; d < rank; ++d) {
|
for (unsigned d = 0; d < rank; ++d) {
|
||||||
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
|
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
|
||||||
resultLayout.getThreadsPerWarp()[d] *
|
resultLayout.getThreadsPerWarp()[d] *
|
||||||
@@ -1457,9 +1430,13 @@ struct BroadcastOpConversion
|
|||||||
}
|
}
|
||||||
resultLogicalShape[d] = numCtas;
|
resultLogicalShape[d] = numCtas;
|
||||||
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
|
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
|
||||||
|
|
||||||
|
srcLogicalOrder[d] = order[d] + rank;
|
||||||
|
srcLogicalOrder[d + rank] = order[d];
|
||||||
}
|
}
|
||||||
int64_t duplicates = 1;
|
int64_t duplicates = 1;
|
||||||
SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
|
SmallVector<int64_t> broadcastSizes(broadcastDims.size() * 2);
|
||||||
|
SmallVector<unsigned> broadcastOrder(broadcastDims.size() * 2);
|
||||||
for (auto it : llvm::enumerate(broadcastDims)) {
|
for (auto it : llvm::enumerate(broadcastDims)) {
|
||||||
// Incase there are multiple indices in the src that is actually
|
// Incase there are multiple indices in the src that is actually
|
||||||
// calculating the same element, srcLogicalShape may not need to be 1.
|
// calculating the same element, srcLogicalShape may not need to be 1.
|
||||||
@@ -1468,30 +1445,44 @@ struct BroadcastOpConversion
|
|||||||
// [1, 2]
|
// [1, 2]
|
||||||
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
|
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
|
||||||
broadcastSizes[it.index()] = d;
|
broadcastSizes[it.index()] = d;
|
||||||
|
broadcastOrder[it.index()] = srcLogicalOrder[it.value()];
|
||||||
duplicates *= d;
|
duplicates *= d;
|
||||||
d = resultLogicalShape[it.value() + rank] /
|
d = resultLogicalShape[it.value() + rank] /
|
||||||
srcLogicalShape[it.value() + rank];
|
srcLogicalShape[it.value() + rank];
|
||||||
broadcastSizes[it.index() + broadcastDims.size()] = d;
|
broadcastSizes[it.index() + broadcastDims.size()] = d;
|
||||||
|
broadcastOrder[it.index() + broadcastDims.size()] =
|
||||||
|
srcLogicalOrder[it.value() + rank];
|
||||||
duplicates *= d;
|
duplicates *= d;
|
||||||
}
|
}
|
||||||
|
auto argsort = [](SmallVector<unsigned> input) {
|
||||||
|
SmallVector<unsigned> idx(input.size());
|
||||||
|
std::iota(idx.begin(), idx.end(), 0);
|
||||||
|
std::sort(idx.begin(), idx.end(), [&input](unsigned a, unsigned b) {
|
||||||
|
return input[a] < input[b];
|
||||||
|
});
|
||||||
|
return idx;
|
||||||
|
};
|
||||||
|
broadcastOrder = argsort(broadcastOrder);
|
||||||
|
|
||||||
unsigned srcElems = getElemsPerThread(srcTy);
|
unsigned srcElems = getElemsPerThread(srcTy);
|
||||||
auto srcVals = getElementsFromStruct(loc, src, rewriter);
|
auto srcVals = getElementsFromStruct(loc, src, rewriter);
|
||||||
unsigned resultElems = getElemsPerThread(resultTy);
|
unsigned resultElems = getElemsPerThread(resultTy);
|
||||||
SmallVector<Value> resultVals(resultElems);
|
SmallVector<Value> resultVals(resultElems);
|
||||||
for (unsigned i = 0; i < srcElems; ++i) {
|
for (unsigned i = 0; i < srcElems; ++i) {
|
||||||
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
|
auto srcMultiDim =
|
||||||
|
getMultiDimIndex<int64_t>(i, srcLogicalShape, srcLogicalOrder);
|
||||||
for (int64_t j = 0; j < duplicates; ++j) {
|
for (int64_t j = 0; j < duplicates; ++j) {
|
||||||
auto resultMultiDim = srcMultiDim;
|
auto resultMultiDim = srcMultiDim;
|
||||||
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
|
auto bcastMultiDim =
|
||||||
|
getMultiDimIndex<int64_t>(j, broadcastSizes, broadcastOrder);
|
||||||
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
|
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
|
||||||
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
|
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
|
||||||
resultMultiDim[bcastDim.value() + rank] +=
|
resultMultiDim[bcastDim.value() + rank] +=
|
||||||
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
|
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
|
||||||
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
|
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
|
||||||
}
|
}
|
||||||
auto resultLinearIndex =
|
auto resultLinearIndex = getLinearIndex<int64_t>(
|
||||||
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
|
resultMultiDim, resultLogicalShape, srcLogicalOrder);
|
||||||
resultVals[resultLinearIndex] = srcVals[i];
|
resultVals[resultLinearIndex] = srcVals[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1665,9 +1656,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|||||||
SmallVector<Value> writeIdx = indices[key];
|
SmallVector<Value> writeIdx = indices[key];
|
||||||
|
|
||||||
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
|
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
|
||||||
Value writeOffset =
|
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
|
||||||
linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
|
|
||||||
reorder<unsigned>(smemShape, srcOrd));
|
|
||||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||||
store(acc, writePtr);
|
store(acc, writePtr);
|
||||||
|
|
||||||
@@ -1676,9 +1665,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|||||||
readIdx[axis] = ints[N];
|
readIdx[axis] = ints[N];
|
||||||
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
|
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
|
||||||
Value readOffset =
|
Value readOffset =
|
||||||
select(readMask,
|
select(readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
|
||||||
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
|
||||||
reorder<unsigned>(smemShape, srcOrd)),
|
|
||||||
ints[0]);
|
ints[0]);
|
||||||
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
|
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
|
||||||
barrier();
|
barrier();
|
||||||
@@ -1702,9 +1689,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|||||||
for (unsigned i = 0; i < resultElems; ++i) {
|
for (unsigned i = 0; i < resultElems; ++i) {
|
||||||
SmallVector<Value> readIdx = resultIndices[i];
|
SmallVector<Value> readIdx = resultIndices[i];
|
||||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||||
Value readOffset =
|
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
|
||||||
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
|
||||||
reorder<unsigned>(smemShape, srcOrd));
|
|
||||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||||
resultVals[i] = load(readPtr);
|
resultVals[i] = load(readPtr);
|
||||||
}
|
}
|
||||||
@@ -1798,9 +1783,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
|
|
||||||
SmallVector<Value> writeIdx = indices[key];
|
SmallVector<Value> writeIdx = indices[key];
|
||||||
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
||||||
Value writeOffset =
|
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
|
||||||
linearize(rewriter, loc, reorder<Value>(writeIdx, order),
|
|
||||||
reorder<unsigned>(smemShape, order));
|
|
||||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||||
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
||||||
}
|
}
|
||||||
@@ -1851,7 +1834,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
||||||
// nd-tensor where n >= 1
|
// nd-tensor where n >= 1
|
||||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||||
auto resultShape = resultTy.getShape();
|
|
||||||
SmallVector<unsigned> resultOrd;
|
SmallVector<unsigned> resultOrd;
|
||||||
for (auto ord : order) {
|
for (auto ord : order) {
|
||||||
if (ord != 0)
|
if (ord != 0)
|
||||||
@@ -1859,15 +1841,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned resultElems = getElemsPerThread(resultTy);
|
unsigned resultElems = getElemsPerThread(resultTy);
|
||||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
|
auto resultIndices =
|
||||||
|
emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
|
||||||
assert(resultIndices.size() == resultElems);
|
assert(resultIndices.size() == resultElems);
|
||||||
|
|
||||||
SmallVector<Value> resultVals(resultElems);
|
SmallVector<Value> resultVals(resultElems);
|
||||||
|
SmallVector<unsigned> resultShape;
|
||||||
|
std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
|
||||||
|
std::back_inserter(resultShape));
|
||||||
for (size_t i = 0; i < resultElems; ++i) {
|
for (size_t i = 0; i < resultElems; ++i) {
|
||||||
SmallVector<Value> readIdx = resultIndices[i];
|
SmallVector<Value> readIdx = resultIndices[i];
|
||||||
Value readOffset =
|
Value readOffset =
|
||||||
linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
|
linearize(rewriter, loc, readIdx, resultShape, resultOrd);
|
||||||
reorder<int64_t, unsigned>(resultShape, resultOrd));
|
|
||||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||||
resultVals[i] = load(readPtr);
|
resultVals[i] = load(readPtr);
|
||||||
}
|
}
|
||||||
@@ -2818,8 +2803,8 @@ private:
|
|||||||
auto multiDimOffsetFirstElem =
|
auto multiDimOffsetFirstElem =
|
||||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
||||||
SmallVector<Value> multiDimOffset(rank);
|
SmallVector<Value> multiDimOffset(rank);
|
||||||
SmallVector<unsigned> multiDimElemId =
|
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
||||||
getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
|
elemId, blockedLayout.getSizePerThread(), blockedLayout.getOrder());
|
||||||
for (unsigned d = 0; d < rank; ++d) {
|
for (unsigned d = 0; d < rank; ++d) {
|
||||||
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
|
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
|
||||||
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
|
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
|
||||||
@@ -2850,9 +2835,7 @@ private:
|
|||||||
Value warpSize = idx_val(32);
|
Value warpSize = idx_val(32);
|
||||||
Value laneId = urem(threadId, warpSize);
|
Value laneId = urem(threadId, warpSize);
|
||||||
Value warpId = udiv(threadId, warpSize);
|
Value warpId = udiv(threadId, warpSize);
|
||||||
// auto multiDimWarpId =
|
// TODO: fix the bug in MMAEncodingAttr document
|
||||||
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
|
|
||||||
// TODO: double confirm if its document bug or DotConversion's Bug
|
|
||||||
SmallVector<Value> multiDimWarpId(2);
|
SmallVector<Value> multiDimWarpId(2);
|
||||||
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||||
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||||
@@ -2942,6 +2925,7 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
auto accumSizePerThread = product<unsigned>(sizePerThread);
|
||||||
SmallVector<unsigned> numCTAs(rank);
|
SmallVector<unsigned> numCTAs(rank);
|
||||||
auto shapePerCTA = getShapePerCTA(layout);
|
auto shapePerCTA = getShapePerCTA(layout);
|
||||||
|
auto order = getOrder(layout);
|
||||||
for (unsigned d = 0; d < rank; ++d) {
|
for (unsigned d = 0; d < rank; ++d) {
|
||||||
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
|
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
|
||||||
}
|
}
|
||||||
@@ -2957,14 +2941,16 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
|
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
|
||||||
|
|
||||||
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
|
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
|
||||||
auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
|
auto multiDimCTAInRepId =
|
||||||
|
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
|
||||||
SmallVector<unsigned> multiDimCTAId(rank);
|
SmallVector<unsigned> multiDimCTAId(rank);
|
||||||
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
|
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
|
||||||
auto d = it.index();
|
auto d = it.index();
|
||||||
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
|
unsigned linearCTAId =
|
||||||
|
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
|
||||||
// TODO: This is actually redundant index calculation, we should
|
// TODO: This is actually redundant index calculation, we should
|
||||||
// consider of caching the index calculation result in case
|
// consider of caching the index calculation result in case
|
||||||
// of performance issue observed.
|
// of performance issue observed.
|
||||||
@@ -2973,8 +2959,7 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
|
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
|
||||||
multiDimCTAInRepId, shapePerCTA);
|
multiDimCTAInRepId, shapePerCTA);
|
||||||
Value offset =
|
Value offset =
|
||||||
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
|
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
|
||||||
reorder<unsigned>(paddedRepShape, outOrd));
|
|
||||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||||
Value ptr = gep(elemPtrTy, smemBase, offset);
|
Value ptr = gep(elemPtrTy, smemBase, offset);
|
||||||
auto vecTy = vec_ty(llvmElemTy, vec);
|
auto vecTy = vec_ty(llvmElemTy, vec);
|
||||||
@@ -3055,7 +3040,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
|
|||||||
SmallVector<Value> outVals(outElems);
|
SmallVector<Value> outVals(outElems);
|
||||||
|
|
||||||
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
|
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
|
||||||
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
|
auto multiDimRepId =
|
||||||
|
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
|
||||||
barrier();
|
barrier();
|
||||||
if (srcLayout.isa<BlockedEncodingAttr>() ||
|
if (srcLayout.isa<BlockedEncodingAttr>() ||
|
||||||
srcLayout.isa<SliceEncodingAttr>() ||
|
srcLayout.isa<SliceEncodingAttr>() ||
|
||||||
|
@@ -154,6 +154,19 @@ def get_variant_golden(a, b):
|
|||||||
c_padded = torch.matmul(a_padded, b_padded)
|
c_padded = torch.matmul(a_padded, b_padded)
|
||||||
return c_padded[:SIZE_M, :SIZE_N]
|
return c_padded[:SIZE_M, :SIZE_N]
|
||||||
|
|
||||||
|
# It's not easy to get a proper error threshold in different size
|
||||||
|
# Here the gemm calculation is padded to a different size in order to get
|
||||||
|
# a variant version of the golden result. And the error between golden and
|
||||||
|
# golden_variant provide reference on selecting the proper rtol / atol.
|
||||||
|
|
||||||
|
|
||||||
|
def get_proper_err(a, b, golden):
|
||||||
|
golden_variant = get_variant_golden(a, b)
|
||||||
|
golden_diff = golden - golden_variant
|
||||||
|
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
|
||||||
|
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
|
||||||
|
return (golden_abs_err, golden_rel_err)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
||||||
# Non-forloop
|
# Non-forloop
|
||||||
@@ -198,16 +211,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
|||||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
num_warps=NUM_WARPS)
|
num_warps=NUM_WARPS)
|
||||||
golden = torch.matmul(a, b)
|
golden = torch.matmul(a, b)
|
||||||
|
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
|
||||||
# It's not easy to get a proper error threshold in different size
|
|
||||||
# Here the gemm calculation is padded to a different size in order to get
|
|
||||||
# a variant version of the golden result. And the error between golden and
|
|
||||||
# golden_variant provide reference on selecting the proper rtol / atol.
|
|
||||||
golden_variant = get_variant_golden(a, b)
|
|
||||||
golden_diff = golden - golden_variant
|
|
||||||
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
|
|
||||||
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
|
|
||||||
|
|
||||||
torch.set_printoptions(profile="full")
|
torch.set_printoptions(profile="full")
|
||||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||||
|
|
||||||
@@ -272,4 +276,5 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
|||||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||||
|
|
||||||
golden = torch.matmul(a, b)
|
golden = torch.matmul(a, b)
|
||||||
torch.testing.assert_close(c, golden)
|
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
|
||||||
|
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
|
||||||
|
@@ -245,12 +245,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
|
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
|
||||||
// CHECK: llvm.mlir.undef
|
// CHECK: llvm.mlir.undef
|
||||||
// CHECK: llvm.insertvalue %[[T0]]
|
// CHECK: llvm.insertvalue %[[T0]]
|
||||||
// CHECK: llvm.insertvalue %[[T0]]
|
// CHECK: llvm.insertvalue %[[T1]]
|
||||||
// CHECK: llvm.insertvalue %[[T0]]
|
|
||||||
// CHECK: llvm.insertvalue %[[T0]]
|
// CHECK: llvm.insertvalue %[[T0]]
|
||||||
// CHECK: llvm.insertvalue %[[T1]]
|
// CHECK: llvm.insertvalue %[[T1]]
|
||||||
|
// CHECK: llvm.insertvalue %[[T0]]
|
||||||
// CHECK: llvm.insertvalue %[[T1]]
|
// CHECK: llvm.insertvalue %[[T1]]
|
||||||
// CHECK: llvm.insertvalue %[[T1]]
|
// CHECK: llvm.insertvalue %[[T0]]
|
||||||
// CHECK: llvm.insertvalue %[[T1]]
|
// CHECK: llvm.insertvalue %[[T1]]
|
||||||
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
|
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
|
||||||
return
|
return
|
||||||
|
Reference in New Issue
Block a user