diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2cbb47e85..b90a8d933 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -356,25 +356,6 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   return llvmStruct;
 }
 
-// TODO[goostavz]: to be deprecated
-// delinearize supposing order is [n, .. , 2, 1, 0]
-template <typename T>
-static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
-  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
-  size_t rank = shape.size();
-  T accMul = product(shape.drop_front());
-  T linearRemain = linearIndex;
-  SmallVector<T> multiDimIndex(rank);
-  for (size_t i = 0; i < rank; ++i) {
-    multiDimIndex[i] = linearRemain / accMul;
-    linearRemain = linearRemain % accMul;
-    if (i != (rank - 1)) {
-      accMul = accMul / shape[i + 1];
-    }
-  }
-  return multiDimIndex;
-}
-
 // delinearize supposing order is [0, 1, .. , n]
 template <typename T>
 static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
@@ -407,24 +388,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
   return multiDim;
 }
 
-// TODO[goostavz]: to be deprecated
-// linearize supposing order is [n, .. , 2, 1, 0]
-template <typename T>
-static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
-  assert(multiDimIndex.size() == shape.size());
-  // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
-  size_t rank = shape.size();
-  T accMul = product(shape.drop_front());
-  T linearIndex = 0;
-  for (size_t i = 0; i < rank; ++i) {
-    linearIndex += multiDimIndex[i] * accMul;
-    if (i != (rank - 1)) {
-      accMul = accMul / shape[i + 1];
-    }
-  }
-  return linearIndex;
-}
-
+// linearize supposing order is [0, 1, .. , n]
 template <typename T>
 static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
@@ -621,6 +585,13 @@ public:
     return multiDim;
   }
 
+  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
+                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
+                  ArrayRef<unsigned> order) const {
+    return linearize(rewriter, loc, reorder(multiDim, order),
+                     reorder(shape, order));
+  }
+
   Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
     int rank = multiDim.size();
@@ -1436,10 +1407,12 @@ struct BroadcastOpConversion
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
+    auto order = srcLayout.getOrder();
 
-    SmallVector<int64_t> srcLogicalShape(2 * rank);
-    SmallVector<int64_t> resultLogicalShape(2 * rank);
-    SmallVector<unsigned> broadcastDims;
+    SmallVector<int64_t> srcLogicalShape(2 * rank);
+    SmallVector<unsigned> srcLogicalOrder(2 * rank);
+    SmallVector<int64_t> resultLogicalShape(2 * rank);
+    SmallVector<unsigned> broadcastDims;
     for (unsigned d = 0; d < rank; ++d) {
       unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
                                    resultLayout.getThreadsPerWarp()[d] *
@@ -1457,9 +1430,13 @@ struct BroadcastOpConversion
       }
       resultLogicalShape[d] = numCtas;
       resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
+
+      srcLogicalOrder[d] = order[d] + rank;
+      srcLogicalOrder[d + rank] = order[d];
     }
     int64_t duplicates = 1;
-    SmallVector<int64_t> broadcastSizes(broadcastDims.size() * 2);
+    SmallVector<int64_t> broadcastSizes(broadcastDims.size() * 2);
+    SmallVector<unsigned> broadcastOrder(broadcastDims.size() * 2);
     for (auto it : llvm::enumerate(broadcastDims)) {
       // Incase there are multiple indices in the src that is actually
       // calculating the same element, srcLogicalShape may not need to be 1.
@@ -1468,30 +1445,44 @@ struct BroadcastOpConversion
       // [1, 2]
       int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
       broadcastSizes[it.index()] = d;
+      broadcastOrder[it.index()] = srcLogicalOrder[it.value()];
       duplicates *= d;
       d = resultLogicalShape[it.value() + rank] /
           srcLogicalShape[it.value() + rank];
       broadcastSizes[it.index() + broadcastDims.size()] = d;
+      broadcastOrder[it.index() + broadcastDims.size()] =
+          srcLogicalOrder[it.value() + rank];
       duplicates *= d;
     }
+    auto argsort = [](SmallVector<unsigned> input) {
+      SmallVector<unsigned> idx(input.size());
+      std::iota(idx.begin(), idx.end(), 0);
+      std::sort(idx.begin(), idx.end(), [&input](unsigned a, unsigned b) {
+        return input[a] < input[b];
+      });
+      return idx;
+    };
+    broadcastOrder = argsort(broadcastOrder);
 
     unsigned srcElems = getElemsPerThread(srcTy);
     auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = getElemsPerThread(resultTy);
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < srcElems; ++i) {
-      auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
+      auto srcMultiDim =
+          getMultiDimIndex<int64_t>(i, srcLogicalShape, srcLogicalOrder);
       for (int64_t j = 0; j < duplicates; ++j) {
         auto resultMultiDim = srcMultiDim;
-        auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
+        auto bcastMultiDim =
+            getMultiDimIndex<int64_t>(j, broadcastSizes, broadcastOrder);
         for (auto bcastDim : llvm::enumerate(broadcastDims)) {
           resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
           resultMultiDim[bcastDim.value() + rank] +=
               bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
               srcLogicalShape[bcastDim.index() + broadcastDims.size()];
         }
-        auto resultLinearIndex =
-            getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
+        auto resultLinearIndex = getLinearIndex<int64_t>(
+            resultMultiDim, resultLogicalShape, srcLogicalOrder);
         resultVals[resultLinearIndex] = srcVals[i];
       }
     }
@@ -1665,9 +1656,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
 
     SmallVector<Value> writeIdx = indices[key];
     writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
-    Value writeOffset =
-        linearize(rewriter, loc, reorder(writeIdx, srcOrd),
-                  reorder(smemShape, srcOrd));
+    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     store(acc, writePtr);
 
@@ -1676,9 +1665,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
       readIdx[axis] = ints[N];
       Value readMask = icmp_slt(writeIdx[axis], ints[N]);
       Value readOffset =
-          select(readMask,
-                 linearize(rewriter, loc, reorder(readIdx, srcOrd),
-                           reorder(smemShape, srcOrd)),
+          select(readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
                  ints[0]);
      Value readPtr = gep(elemPtrTy, writePtr, readOffset);
      barrier();
@@ -1702,9 +1689,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     for (unsigned i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       readIdx.insert(readIdx.begin() + axis, ints[0]);
-      Value readOffset =
-          linearize(rewriter, loc, reorder(readIdx, srcOrd),
-                    reorder(smemShape, srcOrd));
+      Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       resultVals[i] = load(readPtr);
     }
@@ -1798,9 +1783,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
 
       SmallVector<Value> writeIdx = indices[key];
       writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
-      Value writeOffset =
-          linearize(rewriter, loc, reorder(writeIdx, order),
-                    reorder(smemShape, order));
+      Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
       Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
       storeShared(rewriter, loc, writePtr, acc, laneZero);
     }
@@ -1851,7 +1834,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
     // nd-tensor where n >= 1
     auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
-    auto resultShape = resultTy.getShape();
     SmallVector<unsigned> resultOrd;
     for (auto ord : order) {
       if (ord != 0)
@@ -1859,15 +1841,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     }
 
     unsigned resultElems = getElemsPerThread(resultTy);
-    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
+    auto resultIndices =
+        emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
+    SmallVector<unsigned> resultShape;
+    std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
+              std::back_inserter(resultShape));
     for (size_t i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
-      Value readOffset =
-          linearize(rewriter, loc, reorder(readIdx, resultOrd),
-                    reorder(resultShape, resultOrd));
+      Value readOffset =
+          linearize(rewriter, loc, readIdx, resultShape, resultOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       resultVals[i] = load(readPtr);
     }
@@ -2818,8 +2803,8 @@ private:
       auto multiDimOffsetFirstElem =
          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
      SmallVector<Value> multiDimOffset(rank);
-     SmallVector<unsigned> multiDimElemId =
-         getMultiDimIndex<unsigned>(elemId, blockedLayout.getSizePerThread());
+     SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+         elemId, blockedLayout.getSizePerThread(), blockedLayout.getOrder());
      for (unsigned d = 0; d < rank; ++d) {
        multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
                                idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
@@ -2850,9 +2835,7 @@ private:
       Value warpSize = idx_val(32);
       Value laneId = urem(threadId, warpSize);
       Value warpId = udiv(threadId, warpSize);
-      // auto multiDimWarpId =
-      //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
-      // TODO: double confirm if its document bug or DotConversion's Bug
+      // TODO: fix the bug in MMAEncodingAttr document
       SmallVector<Value> multiDimWarpId(2);
       multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
       multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
@@ -2942,6 +2925,7 @@ void ConvertLayoutOpConversion::processReplica(
   auto accumSizePerThread = product(sizePerThread);
   SmallVector<unsigned> numCTAs(rank);
   auto shapePerCTA = getShapePerCTA(layout);
+  auto order = getOrder(layout);
   for (unsigned d = 0; d < rank; ++d) {
     numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]);
   }
@@ -2957,14 +2941,16 @@ void ConvertLayoutOpConversion::processReplica(
   auto llvmElemTy = getTypeConverter()->convertType(elemTy);
 
   for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
-    auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
+    auto multiDimCTAInRepId =
+        getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
     SmallVector<unsigned> multiDimCTAId(rank);
     for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
       auto d = it.index();
       multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
     }
 
-    unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
+    unsigned linearCTAId =
+        getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
     // TODO: This is actually redundant index calculation, we should
    //       consider of caching the index calculation result in case
    //       of performance issue observed.
@@ -2973,8 +2959,7 @@ void ConvertLayoutOpConversion::processReplica(
           getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
                             multiDimCTAInRepId, shapePerCTA);
       Value offset =
-          linearize(rewriter, loc, reorder(multiDimOffset, outOrd),
-                    reorder(paddedRepShape, outOrd));
+          linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
       auto elemPtrTy = ptr_ty(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3055,7 +3040,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
   SmallVector<Value> outVals(outElems);
 
   for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
-    auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
+    auto multiDimRepId =
+        getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
     barrier();
     if (srcLayout.isa<BlockedEncodingAttr>() ||
         srcLayout.isa<SliceEncodingAttr>() ||
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 17fca9f6e..bd8ddd27c 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -154,6 +154,19 @@ def get_variant_golden(a, b):
     c_padded = torch.matmul(a_padded, b_padded)
     return c_padded[:SIZE_M, :SIZE_N]
 
+# It's not easy to get a proper error threshold for different sizes, so the
+# gemm calculation is padded to a different size in order to get a variant
+# version of the golden result. The error between golden and golden_variant
+# provides a reference for selecting a proper rtol / atol.
+
+
+def get_proper_err(a, b, golden):
+    golden_variant = get_variant_golden(a, b)
+    golden_diff = golden - golden_variant
+    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
+    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
+    return (golden_abs_err, golden_rel_err)
+
 
 @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
@@ -198,16 +211,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
                    BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
-
-    # It's not easy to get a proper error threshold in different size
-    # Here the gemm calculation is padded to a different size in order to get
-    # a variant version of the golden result. And the error between golden and
-    # golden_variant provide reference on selecting the proper rtol / atol.
-    golden_variant = get_variant_golden(a, b)
-    golden_diff = golden - golden_variant
-    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
-    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
-
+    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
     torch.set_printoptions(profile="full")
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@@ -272,4 +276,5 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
                       BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
     golden = torch.matmul(a, b)
-    torch.testing.assert_close(c, golden)
+    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
+    torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index bcb729ef0..62286b21d 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -245,12 +245,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
     // CHECK: llvm.mlir.undef
     // CHECK: llvm.insertvalue %[[T0]]
-    // CHECK: llvm.insertvalue %[[T0]]
-    // CHECK: llvm.insertvalue %[[T0]]
+    // CHECK: llvm.insertvalue %[[T1]]
     // CHECK: llvm.insertvalue %[[T0]]
     // CHECK: llvm.insertvalue %[[T1]]
+    // CHECK: llvm.insertvalue %[[T0]]
     // CHECK: llvm.insertvalue %[[T1]]
-    // CHECK: llvm.insertvalue %[[T1]]
+    // CHECK: llvm.insertvalue %[[T0]]
    // CHECK: llvm.insertvalue %[[T1]]
    %1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
    return
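
Reviewer note: the thrust of the C++ changes is that every linearize/delinearize now threads through an explicit `order` (where `order[0]` names the fastest-varying dimension, as in triton_gpu layout encodings) instead of assuming a fixed row-major order. Below is a minimal Python sketch of the convention the new `getMultiDimIndex` / `getLinearIndex` overloads follow; the names mirror the C++ helpers, but the sketch is illustrative rather than a transcription of the patch:

    def get_multi_dim_index(linear, shape, order):
        # Peel off remainders dimension by dimension, fastest-varying first.
        multi = [0] * len(shape)
        remain = linear
        for d in order:
            multi[d] = remain % shape[d]
            remain //= shape[d]
        return multi

    def get_linear_index(multi, shape, order):
        # Inverse of get_multi_dim_index for the same order.
        linear, stride = 0, 1
        for d in order:
            linear += multi[d] * stride
            stride *= shape[d]
        return linear

    # With order [0, 1], dim 0 is the fastest-varying one: linear index 3 of
    # shape [2, 3] maps to [1, 1], not to [1, 0] as row-major would give.
    assert get_multi_dim_index(3, [2, 3], [0, 1]) == [1, 1]
    assert all(get_linear_index(get_multi_dim_index(i, [2, 3], [0, 1]),
                                [2, 3], [0, 1]) == i for i in range(6))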
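
The argsort step in BroadcastOpConversion deserves a word: `broadcastOrder` is filled with values taken from `srcLogicalOrder`, which ranges over the full `2 * rank` logical index space, so for the smaller `broadcastSizes` array those values are rankings rather than a permutation of `[0, n)` and cannot be passed to `getMultiDimIndex` as-is. Argsorting converts them into a valid order with the same relative ranking. A small sketch of that step (the input values here are hypothetical):

    def argsort(values):
        # Indices that would sort `values` ascending: smallest ranking first.
        return sorted(range(len(values)), key=lambda i: values[i])

    # Rankings 3, 0, 2 become the permutation [1, 2, 0]: the position that
    # held the smallest ranking (index 1) is listed first.
    assert argsort([3, 0, 2]) == [1, 2, 0]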
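
On the test side, `get_proper_err` packages the tolerance heuristic `test_gemm` already used so that `test_gemm_fmadot` can share it: two equally valid floating-point evaluations of the same matmul differ only by rounding, and 1.5x that observed gap (floored at 1e-4) serves as the pass threshold. A self-contained illustration of the idea, using an arbitrary split-K regrouping in place of `get_variant_golden`'s padding:

    import torch

    torch.manual_seed(0)
    a = torch.randn(256, 128)
    b = torch.randn(128, 256)
    golden = torch.matmul(a, b)
    # Same math, different summation grouping -> differs by rounding only.
    variant = (torch.matmul(a[:, :64], b[:64, :]) +
               torch.matmul(a[:, 64:], b[64:, :]))
    golden_diff = golden - variant
    golden_abs_err = torch.max(torch.abs(golden_diff)).item()
    golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
    # The tests then allow 1.5x the observed error, with a 1e-4 floor.
    rtol, atol = max(1e-4, 1.5 * golden_rel_err), max(1e-4, 1.5 * golden_abs_err)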