[Triton-MLIR][Backend] Some cleanup in getMultiDimIndex/getLinearIndex (#880)
@@ -356,25 +356,6 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   return llvmStruct;
 }
 
-// TODO[goostavz]: to be deprecated
-// delinearize supposing order is [n, .. , 2, 1, 0]
-template <typename T>
-static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
-  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
-  size_t rank = shape.size();
-  T accMul = product(shape.drop_front());
-  T linearRemain = linearIndex;
-  SmallVector<T> multiDimIndex(rank);
-  for (size_t i = 0; i < rank; ++i) {
-    multiDimIndex[i] = linearRemain / accMul;
-    linearRemain = linearRemain % accMul;
-    if (i != (rank - 1)) {
-      accMul = accMul / shape[i + 1];
-    }
-  }
-  return multiDimIndex;
-}
-
 // delinearize supposing order is [0, 1, .. , n]
 template <typename T>
 static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
@@ -407,24 +388,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
   return multiDim;
 }
-
-// TODO[goostavz]: to be deprecated
-// linearize supposing order is [n, .. , 2, 1, 0]
-template <typename T>
-static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
-  assert(multiDimIndex.size() == shape.size());
-  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
-  size_t rank = shape.size();
-  T accMul = product(shape.drop_front());
-  T linearIndex = 0;
-  for (size_t i = 0; i < rank; ++i) {
-    linearIndex += multiDimIndex[i] * accMul;
-    if (i != (rank - 1)) {
-      accMul = accMul / shape[i + 1];
-    }
-  }
-  return linearIndex;
-}
+
 // linearize supposing order is [0, 1, .. , n]
 template <typename T>
 static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
@@ -621,6 +585,13 @@ public:
     return multiDim;
   }
 
+  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
+                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
+                  ArrayRef<unsigned> order) const {
+    return linearize(rewriter, loc, reorder<Value>(multiDim, order),
+                     reorder<unsigned>(shape, order));
+  }
+
   Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
     int rank = multiDim.size();
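For orientation: the helpers this cleanup standardizes on treat order[0] as the fastest-varying dimension, so getMultiDimIndex(linearIndex, shape, order) and getLinearIndex(multiDimIndex, shape, order) take over from the deprecated row-major ("[n, .. , 2, 1, 0]") variants removed above, and the new linearize overload just reorders its arguments before calling the plain version. The sketch below is a standalone illustration of that index math, not code from this patch: it uses plain std::vector instead of llvm::SmallVector/ArrayRef, and the names delinearize/linearizeIdx are made up for the example.

#include <cassert>
#include <cstdint>
#include <vector>

// Delinearize `linear` over `shape`, assuming order[0] names the
// fastest-varying dimension (Triton's `order` convention).
static std::vector<int64_t> delinearize(int64_t linear,
                                        const std::vector<int64_t> &shape,
                                        const std::vector<unsigned> &order) {
  std::vector<int64_t> multi(shape.size());
  for (unsigned d : order) { // walk dimensions fastest -> slowest
    multi[d] = linear % shape[d];
    linear /= shape[d];
  }
  return multi;
}

// Inverse of delinearize: fold the multi-dim index back into a linear index,
// accumulating from the slowest-varying dimension down to the fastest.
static int64_t linearizeIdx(const std::vector<int64_t> &multi,
                            const std::vector<int64_t> &shape,
                            const std::vector<unsigned> &order) {
  int64_t linear = 0;
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    linear = linear * shape[*it] + multi[*it];
  return linear;
}

int main() {
  // shape {4, 8} with order {1, 0}: dimension 1 is fastest-varying,
  // so linear index 10 = 1 * 8 + 2 corresponds to multi-dim index {1, 2}.
  std::vector<int64_t> shape{4, 8};
  std::vector<unsigned> order{1, 0};
  auto multi = delinearize(10, shape, order);
  assert(multi[0] == 1 && multi[1] == 2);
  assert(linearizeIdx(multi, shape, order) == 10);
  return 0;
}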
@@ -1436,10 +1407,12 @@ struct BroadcastOpConversion
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
+    auto order = srcLayout.getOrder();
 
-    SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
-    SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
-    SmallVector<unsigned, 2> broadcastDims;
+    SmallVector<int64_t> srcLogicalShape(2 * rank);
+    SmallVector<unsigned> srcLogicalOrder(2 * rank);
+    SmallVector<int64_t> resultLogicalShape(2 * rank);
+    SmallVector<unsigned> broadcastDims;
     for (unsigned d = 0; d < rank; ++d) {
       unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
                                    resultLayout.getThreadsPerWarp()[d] *
@@ -1457,9 +1430,13 @@ struct BroadcastOpConversion
       }
       resultLogicalShape[d] = numCtas;
       resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+
+      srcLogicalOrder[d] = order[d] + rank;
+      srcLogicalOrder[d + rank] = order[d];
     }
     int64_t duplicates = 1;
-    SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
+    SmallVector<int64_t> broadcastSizes(broadcastDims.size() * 2);
+    SmallVector<unsigned> broadcastOrder(broadcastDims.size() * 2);
     for (auto it : llvm::enumerate(broadcastDims)) {
       // Incase there are multiple indices in the src that is actually
       // calculating the same element, srcLogicalShape may not need to be 1.
@@ -1468,30 +1445,44 @@ struct BroadcastOpConversion
       // [1, 2]
       int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
       broadcastSizes[it.index()] = d;
+      broadcastOrder[it.index()] = srcLogicalOrder[it.value()];
       duplicates *= d;
       d = resultLogicalShape[it.value() + rank] /
          srcLogicalShape[it.value() + rank];
       broadcastSizes[it.index() + broadcastDims.size()] = d;
+      broadcastOrder[it.index() + broadcastDims.size()] =
+          srcLogicalOrder[it.value() + rank];
       duplicates *= d;
     }
+    auto argsort = [](SmallVector<unsigned> input) {
+      SmallVector<unsigned> idx(input.size());
+      std::iota(idx.begin(), idx.end(), 0);
+      std::sort(idx.begin(), idx.end(), [&input](unsigned a, unsigned b) {
+        return input[a] < input[b];
+      });
+      return idx;
+    };
+    broadcastOrder = argsort(broadcastOrder);
 
     unsigned srcElems = getElemsPerThread(srcTy);
     auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = getElemsPerThread(resultTy);
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < srcElems; ++i) {
-      auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
+      auto srcMultiDim =
+          getMultiDimIndex<int64_t>(i, srcLogicalShape, srcLogicalOrder);
       for (int64_t j = 0; j < duplicates; ++j) {
         auto resultMultiDim = srcMultiDim;
-        auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
+        auto bcastMultiDim =
+            getMultiDimIndex<int64_t>(j, broadcastSizes, broadcastOrder);
         for (auto bcastDim : llvm::enumerate(broadcastDims)) {
           resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
           resultMultiDim[bcastDim.value() + rank] +=
              bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
              srcLogicalShape[bcastDim.index() + broadcastDims.size()];
         }
-        auto resultLinearIndex =
-            getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
+        auto resultLinearIndex = getLinearIndex<int64_t>(
+            resultMultiDim, resultLogicalShape, srcLogicalOrder);
         resultVals[resultLinearIndex] = srcVals[i];
       }
     }
@@ -1665,9 +1656,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     SmallVector<Value> writeIdx = indices[key];
 
     writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
-    Value writeOffset =
-        linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
-                  reorder<unsigned>(smemShape, srcOrd));
+    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     store(acc, writePtr);
 
@@ -1676,9 +1665,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
       readIdx[axis] = ints[N];
       Value readMask = icmp_slt(writeIdx[axis], ints[N]);
       Value readOffset =
-          select(readMask,
-                 linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
-                           reorder<unsigned>(smemShape, srcOrd)),
+          select(readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
                  ints[0]);
       Value readPtr = gep(elemPtrTy, writePtr, readOffset);
       barrier();
@@ -1702,9 +1689,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     for (unsigned i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       readIdx.insert(readIdx.begin() + axis, ints[0]);
-      Value readOffset =
-          linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
-                    reorder<unsigned>(smemShape, srcOrd));
+      Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       resultVals[i] = load(readPtr);
     }
@@ -1798,9 +1783,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
 
     SmallVector<Value> writeIdx = indices[key];
     writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
-    Value writeOffset =
-        linearize(rewriter, loc, reorder<Value>(writeIdx, order),
-                  reorder<unsigned>(smemShape, order));
+    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     storeShared(rewriter, loc, writePtr, acc, laneZero);
   }
@@ -1851,7 +1834,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
     // nd-tensor where n >= 1
     auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
-    auto resultShape = resultTy.getShape();
     SmallVector<unsigned> resultOrd;
     for (auto ord : order) {
       if (ord != 0)
@@ -1859,15 +1841,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     }
 
     unsigned resultElems = getElemsPerThread(resultTy);
-    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
+    auto resultIndices =
+        emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
+    SmallVector<unsigned> resultShape;
+    std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
+              std::back_inserter(resultShape));
     for (size_t i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       Value readOffset =
-          linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
-                    reorder<int64_t, unsigned>(resultShape, resultOrd));
+          linearize(rewriter, loc, readIdx, resultShape, resultOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       resultVals[i] = load(readPtr);
     }
@@ -2818,8 +2803,8 @@ private:
     auto multiDimOffsetFirstElem =
        emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
     SmallVector<Value> multiDimOffset(rank);
-    SmallVector<unsigned> multiDimElemId =
-        getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
+    SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+        elemId, blockedLayout.getSizePerThread(), blockedLayout.getOrder());
     for (unsigned d = 0; d < rank; ++d) {
       multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
                               idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
@@ -2850,9 +2835,7 @@ private:
       Value warpSize = idx_val(32);
       Value laneId = urem(threadId, warpSize);
       Value warpId = udiv(threadId, warpSize);
-      // auto multiDimWarpId =
-      //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
-      // TODO: double confirm if its document bug or DotConversion's Bug
+      // TODO: fix the bug in MMAEncodingAttr document
       SmallVector<Value> multiDimWarpId(2);
       multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
       multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
@@ -2942,6 +2925,7 @@ void ConvertLayoutOpConversion::processReplica(
   auto accumSizePerThread = product<unsigned>(sizePerThread);
   SmallVector<unsigned> numCTAs(rank);
   auto shapePerCTA = getShapePerCTA(layout);
+  auto order = getOrder(layout);
   for (unsigned d = 0; d < rank; ++d) {
     numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
   }
@@ -2957,14 +2941,16 @@ void ConvertLayoutOpConversion::processReplica(
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);
 
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
-    auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
+    auto multiDimCTAInRepId =
+        getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
     SmallVector<unsigned> multiDimCTAId(rank);
     for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
       auto d = it.index();
       multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
     }
 
-    unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
+    unsigned linearCTAId =
+        getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
     // TODO: This is actually redundant index calculation, we should
     // consider of caching the index calculation result in case
     // of performance issue observed.
@@ -2973,8 +2959,7 @@ void ConvertLayoutOpConversion::processReplica(
           getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
                             multiDimCTAInRepId, shapePerCTA);
       Value offset =
-          linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
-                    reorder<unsigned>(paddedRepShape, outOrd));
+          linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
       auto elemPtrTy = ptr_ty(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3055,7 +3040,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   SmallVector<Value> outVals(outElems);
 
   for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
-    auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
+    auto multiDimRepId =
+        getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
     barrier();
     if (srcLayout.isa<BlockedEncodingAttr>() ||
         srcLayout.isa<SliceEncodingAttr>() ||
@@ -154,6 +154,19 @@ def get_variant_golden(a, b):
     c_padded = torch.matmul(a_padded, b_padded)
     return c_padded[:SIZE_M, :SIZE_N]
 
+# It's not easy to get a proper error threshold in different size
+# Here the gemm calculation is padded to a different size in order to get
+# a variant version of the golden result. And the error between golden and
+# golden_variant provide reference on selecting the proper rtol / atol.
+
+
+def get_proper_err(a, b, golden):
+    golden_variant = get_variant_golden(a, b)
+    golden_diff = golden - golden_variant
+    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
+    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
+    return (golden_abs_err, golden_rel_err)
+
 
 @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
@@ -198,16 +211,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
                        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                        num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
 
-    # It's not easy to get a proper error threshold in different size
-    # Here the gemm calculation is padded to a different size in order to get
-    # a variant version of the golden result. And the error between golden and
-    # golden_variant provide reference on selecting the proper rtol / atol.
-    golden_variant = get_variant_golden(a, b)
-    golden_diff = golden - golden_variant
-    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
-    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
-
+    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
     torch.set_printoptions(profile="full")
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@@ -272,4 +276,5 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
 
     golden = torch.matmul(a, b)
-    torch.testing.assert_close(c, golden)
+    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
+    torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
@@ -245,12 +245,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
    %0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
    // CHECK: llvm.mlir.undef
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T1]]
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T1]]
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T1]]
    // CHECK: llvm.insertvalue %[[T1]]
    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T1]]
    %1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
    return