[Triton-MLIR][Backend] Some cleanup in getMultiDimIndex/getLinearIndex (#880)

goostavz
2022-11-18 09:19:21 +08:00
committed by GitHub
parent 5eee738df7
commit 9ea6135eb5
3 changed files with 76 additions and 85 deletions
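
The refactor threads an explicit `order` argument through the index helpers: the deprecated getMultiDimIndex/getLinearIndex below hard-code the [n, .. , 2, 1, 0] (last-dimension-fastest) order, while the surviving overloads take the order from the layout. As a reading aid, here is a minimal standalone sketch of the delinearize direction, assuming (as the [n, .. , 2, 1, 0] and [0, 1, .. , n] comments in the first hunk suggest) that order[0] names the fastest-varying dimension; it is an illustration, not the Triton implementation.

// Standalone sketch (not the Triton helpers themselves): delinearization with
// an explicit `order`, assuming order[0] is the fastest-varying dimension.
#include <cassert>
#include <cstdio>
#include <vector>

using Vec = std::vector<unsigned>;

// linearIndex -> multi-dimensional index, peeling off dimensions in `order`.
static Vec getMultiDimIndex(unsigned linearIndex, const Vec &shape,
                            const Vec &order) {
  assert(shape.size() == order.size());
  Vec multiDim(shape.size());
  unsigned remain = linearIndex;
  for (unsigned d : order) { // fastest-varying dimension first
    multiDim[d] = remain % shape[d];
    remain /= shape[d];
  }
  return multiDim;
}

int main() {
  Vec shape = {2, 3, 4};
  // order {2, 1, 0} reproduces the deprecated row-major behaviour
  // (dimension 2 fastest): 13 = 1*12 + 0*4 + 1 -> {1, 0, 1}.
  Vec m = getMultiDimIndex(13, shape, {2, 1, 0});
  std::printf("{%u, %u, %u}\n", m[0], m[1], m[2]);
  // order {0, 1, 2} makes dimension 0 fastest instead: 13 = 1 + 0*2 + 2*6 -> {1, 0, 2}.
  m = getMultiDimIndex(13, shape, {0, 1, 2});
  std::printf("{%u, %u, %u}\n", m[0], m[1], m[2]);
  return 0;
}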


@@ -356,25 +356,6 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
return llvmStruct;
}
// TODO[goostavz]: to be deprecated
// delinearize supposing order is [n, .. , 2, 1, 0]
template <typename T>
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
size_t rank = shape.size();
T accMul = product(shape.drop_front());
T linearRemain = linearIndex;
SmallVector<T> multiDimIndex(rank);
for (size_t i = 0; i < rank; ++i) {
multiDimIndex[i] = linearRemain / accMul;
linearRemain = linearRemain % accMul;
if (i != (rank - 1)) {
accMul = accMul / shape[i + 1];
}
}
return multiDimIndex;
}
// delinearize supposing order is [0, 1, .. , n]
template <typename T>
static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
@@ -407,24 +388,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
return multiDim;
}
// TODO[goostavz]: to be deprecated
// linearize supposing order is [n, .. , 2, 1, 0]
template <typename T>
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
size_t rank = shape.size();
T accMul = product(shape.drop_front());
T linearIndex = 0;
for (size_t i = 0; i < rank; ++i) {
linearIndex += multiDimIndex[i] * accMul;
if (i != (rank - 1)) {
accMul = accMul / shape[i + 1];
}
}
return linearIndex;
}
// linearize supposing order is [0, 1, .. , n]
template <typename T>
static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
@@ -621,6 +585,13 @@ public:
return multiDim;
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
reorder<unsigned>(shape, order));
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
int rank = multiDim.size();
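
The new overload above folds the reorder step into linearize itself, which is what lets the call sites later in this diff drop their explicit reorder<Value>/reorder<unsigned> pairs. A hedged integer analogue of the equivalence, assuming reorder(x, order)[i] == x[order[i]] (reorder's definition is not part of this diff) and that the plain overload treats dimension 0 as the fastest-varying one:

// Sketch only: an integer analogue of the new linearize overload, under the
// assumptions stated above; not the Value-based Triton code.
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<unsigned>;

static Vec reorder(const Vec &x, const Vec &order) {
  Vec out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[order[i]];
  return out;
}

// Plain overload: dimension 0 fastest.
static unsigned linearize(const Vec &multiDim, const Vec &shape) {
  unsigned linear = 0, stride = 1;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    linear += multiDim[i] * stride;
    stride *= shape[i];
  }
  return linear;
}

// Order-aware overload: permute both operands, then reuse the plain overload,
// mirroring the Value-based version above.
static unsigned linearize(const Vec &multiDim, const Vec &shape,
                          const Vec &order) {
  return linearize(reorder(multiDim, order), reorder(shape, order));
}

int main() {
  // Index {row=1, col=2} in a 4x8 tile with the column dimension fastest
  // (order {1, 0}): expected linear offset 2 + 1 * 8 = 10.
  assert(linearize({1, 2}, {4, 8}, {1, 0}) == 10);
  return 0;
}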
@@ -1436,10 +1407,12 @@ struct BroadcastOpConversion
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
assert(rank == resultTy.getRank());
auto order = srcLayout.getOrder();
SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
SmallVector<unsigned, 2> broadcastDims;
SmallVector<int64_t> srcLogicalShape(2 * rank);
SmallVector<unsigned> srcLogicalOrder(2 * rank);
SmallVector<int64_t> resultLogicalShape(2 * rank);
SmallVector<unsigned> broadcastDims;
for (unsigned d = 0; d < rank; ++d) {
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
@@ -1457,9 +1430,13 @@ struct BroadcastOpConversion
}
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
srcLogicalOrder[d] = order[d] + rank;
srcLogicalOrder[d + rank] = order[d];
}
int64_t duplicates = 1;
SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
SmallVector<int64_t> broadcastSizes(broadcastDims.size() * 2);
SmallVector<unsigned> broadcastOrder(broadcastDims.size() * 2);
for (auto it : llvm::enumerate(broadcastDims)) {
// In case there are multiple indices in the src that actually
// compute the same element, srcLogicalShape may not need to be 1.
@@ -1468,30 +1445,44 @@ struct BroadcastOpConversion
// [1, 2]
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
broadcastSizes[it.index()] = d;
broadcastOrder[it.index()] = srcLogicalOrder[it.value()];
duplicates *= d;
d = resultLogicalShape[it.value() + rank] /
srcLogicalShape[it.value() + rank];
broadcastSizes[it.index() + broadcastDims.size()] = d;
broadcastOrder[it.index() + broadcastDims.size()] =
srcLogicalOrder[it.value() + rank];
duplicates *= d;
}
auto argsort = [](SmallVector<unsigned> input) {
SmallVector<unsigned> idx(input.size());
std::iota(idx.begin(), idx.end(), 0);
std::sort(idx.begin(), idx.end(), [&input](unsigned a, unsigned b) {
return input[a] < input[b];
});
return idx;
};
broadcastOrder = argsort(broadcastOrder);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcVals = getElementsFromStruct(loc, src, rewriter);
unsigned resultElems = getElemsPerThread(resultTy);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
auto srcMultiDim =
getMultiDimIndex<int64_t>(i, srcLogicalShape, srcLogicalOrder);
for (int64_t j = 0; j < duplicates; ++j) {
auto resultMultiDim = srcMultiDim;
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
auto bcastMultiDim =
getMultiDimIndex<int64_t>(j, broadcastSizes, broadcastOrder);
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value() + rank] +=
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
}
auto resultLinearIndex =
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
auto resultLinearIndex = getLinearIndex<int64_t>(
resultMultiDim, resultLogicalShape, srcLogicalOrder);
resultVals[resultLinearIndex] = srcVals[i];
}
}
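
The argsort lambda above returns, for each position k, the index of the k-th smallest input value; assigning its result back to broadcastOrder converts the collected srcLogicalOrder values into a permutation that can be fed to getMultiDimIndex as an order. A standalone check of that behaviour, with purely illustrative values:

// Sketch: same idea as the lambda above, over std::vector for a standalone check.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

static std::vector<unsigned> argsort(const std::vector<unsigned> &input) {
  std::vector<unsigned> idx(input.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&input](unsigned a, unsigned b) { return input[a] < input[b]; });
  return idx;
}

int main() {
  // Hypothetical logical-order values collected for the broadcast dimensions.
  std::vector<unsigned> order = {3, 0, 2, 1};
  std::vector<unsigned> perm = argsort(order);
  // Positions sorted by value: 0 is at index 1, 1 at 3, 2 at 2, 3 at 0.
  assert((perm == std::vector<unsigned>{1, 3, 2, 0}));
  return 0;
}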
@@ -1665,9 +1656,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
Value writeOffset =
linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd));
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
store(acc, writePtr);
@@ -1676,9 +1665,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
readIdx[axis] = ints[N];
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
Value readOffset =
select(readMask,
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd)),
select(readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
ints[0]);
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
@@ -1702,9 +1689,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset =
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
}
@@ -1798,9 +1783,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, reorder<Value>(writeIdx, order),
reorder<unsigned>(smemShape, order));
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
}
@@ -1851,7 +1834,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
SmallVector<unsigned> resultOrd;
for (auto ord : order) {
if (ord != 0)
@@ -1859,15 +1841,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
}
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
SmallVector<unsigned> resultShape;
std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
std::back_inserter(resultShape));
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
Value readOffset =
linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
reorder<int64_t, unsigned>(resultShape, resultOrd));
linearize(rewriter, loc, readIdx, resultShape, resultOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
}
@@ -2818,8 +2803,8 @@ private:
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId =
getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread(), blockedLayout.getOrder());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
@@ -2850,9 +2835,7 @@ private:
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// auto multiDimWarpId =
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
// TODO: double-check whether it is a documentation bug or DotConversion's bug
// TODO: fix the bug in the MMAEncodingAttr documentation
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
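
The two lines above split warpId by hand, with dimension 0 varying fastest. A tiny worked example with a hypothetical 2x4 warp grid:

// Worked example (hypothetical values, not taken from this commit).
#include <cassert>
int main() {
  unsigned warpsPerCTA0 = 2; // warpsPerCTA = {2, 4}
  unsigned warpId = 5;
  unsigned multiDimWarpId0 = warpId % warpsPerCTA0; // urem -> 1
  unsigned multiDimWarpId1 = warpId / warpsPerCTA0; // udiv -> 2
  // Warp 5 therefore sits at {1, 2} in the 2x4 warp grid.
  assert(multiDimWarpId0 == 1 && multiDimWarpId1 == 2);
  return 0;
}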
@@ -2942,6 +2925,7 @@ void ConvertLayoutOpConversion::processReplica(
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
@@ -2957,14 +2941,16 @@ void ConvertLayoutOpConversion::processReplica(
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
unsigned linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
// TODO: This is actually redundant index calculation; we should
// consider caching the index calculation result in case a
// performance issue is observed.
@@ -2973,8 +2959,7 @@ void ConvertLayoutOpConversion::processReplica(
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3055,7 +3040,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
barrier();
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||


@@ -154,6 +154,19 @@ def get_variant_golden(a, b):
c_padded = torch.matmul(a_padded, b_padded)
return c_padded[:SIZE_M, :SIZE_N]
# It's not easy to pick a proper error threshold for different sizes.
# Here the gemm calculation is padded to a different size in order to get
# a variant version of the golden result, and the error between golden and
# golden_variant provides a reference for selecting proper rtol / atol.
def get_proper_err(a, b, golden):
golden_variant = get_variant_golden(a, b)
golden_diff = golden - golden_variant
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
return (golden_abs_err, golden_rel_err)
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
@@ -198,16 +211,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=NUM_WARPS)
golden = torch.matmul(a, b)
# It's not easy to get a proper error threshold in different size
# Here the gemm calculation is padded to a different size in order to get
# a variant version of the golden result. And the error between golden and
# golden_variant provide reference on selecting the proper rtol / atol.
golden_variant = get_variant_golden(a, b)
golden_diff = golden - golden_variant
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@@ -272,4 +276,5 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
golden = torch.matmul(a, b)
torch.testing.assert_close(c, golden)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


@@ -245,12 +245,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
return