diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index ad64e3ddd..982ba7d37 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -347,7 +347,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   return llvmStruct;
 }

-// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// TODO[goostavz]: to be deprecated
+// delinearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
@@ -365,7 +366,40 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
   return multiDimIndex;
 }

-// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
+// delinearize supposing order is [0, 1, .. , n]
+template <typename T>
+static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearRemain = linearIndex;
+  SmallVector<T> multiDimIndex(rank);
+  for (int i = rank - 1; i >= 0; --i) {
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return multiDimIndex;
+}
+
+template <typename T>
+static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
+                                       ArrayRef<unsigned> order) {
+  size_t rank = shape.size();
+  assert(rank == order.size());
+  auto reordered = reorder(shape, order);
+  auto reorderedMultiDim = getMultiDimIndexImpl(linearIndex, reordered);
+  SmallVector<T> multiDim(rank);
+  for (unsigned i = 0; i < rank; ++i) {
+    multiDim[order[i]] = reorderedMultiDim[i];
+  }
+  return multiDim;
+}
+
+// TODO[goostavz]: to be deprecated
+// linearize supposing order is [n, .. , 2, 1, 0]
 template <typename T>
 static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
@@ -382,6 +416,30 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   return linearIndex;
 }

+template <typename T>
+static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearIndex = 0;
+  for (int i = rank - 1; i >= 0; --i) {
+    linearIndex += multiDimIndex[i] * accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return linearIndex;
+}
+
+template <typename T>
+static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
+                        ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  return getLinearIndexImpl(reorder(multiDimIndex, order),
+                            reorder(shape, order));
+}
+
 static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
                          Value ptr, Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
@@ -632,6 +690,7 @@ public:
     auto sizePerThread = blockedLayout.getSizePerThread();
     auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
+    auto order = blockedLayout.getOrder();
     unsigned rank = shape.size();
     SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
@@ -663,9 +722,9 @@ public:
       unsigned linearNanoTileId = n / totalSizePerThread;
       unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
-          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
-      SmallVector<unsigned> multiDimNanoTileElemId =
-          getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
+          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
+      SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
+          linearNanoTileElemId, sizePerThread, order);
       for (unsigned k = 0; k < rank; ++k) {
         unsigned reorderedMultiDimId =
             multiDimNanoTileId[k] *
@@ -1881,8 +1940,6 @@ struct PrintfOpConversion
   // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
   std::string getFormatSubstr(Value value) const {
     Type type = value.getType();
-    unsigned width = type.getIntOrFloatBitWidth();
-
     if (type.isa()) {
       return "%p";
     } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
@@ -1924,13 +1981,11 @@ struct PrintfOpConversion
   promoteValue(ConversionPatternRewriter &rewriter, Value value) {
     auto *context = rewriter.getContext();
     auto type = value.getType();
-    type.dump();
-    unsigned width = type.getIntOrFloatBitWidth();
     Value newOp = value;
     Type newType = type;
     bool bUnsigned = type.isUnsignedInteger();

-    if (type.isIntOrIndex() && width < 32) {
+    if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
       if (bUnsigned) {
         newType = ui32_ty;
         newOp = rewriter.create(UnknownLoc::get(context), newType,
@@ -3057,23 +3112,24 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     }
     unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
     auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
+        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
     unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
     multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
+        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
     wordVecs[wordVecIdx] =
         insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
     if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
       // end of replication, store the vectors into shared memory
       unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
+      auto multiDimRepIdx =
+          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
       for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
            ++linearWordIdx) {
         // step 1: recover the multidim_index from the index of input_elements
         auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
+            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
         SmallVector<Value> multiDimIdx(2);
         auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
                            multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
@@ -3083,12 +3139,12 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
         multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));

         // step 2: do swizzling
-        Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
-        multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
-        Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
+        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+        Value off_1 = mul(multiDimIdx[outOrd[1]], idx_val(srcShape[outOrd[0]]));
+        Value phaseId = udiv(multiDimIdx[outOrd[1]], idx_val(perPhase));
         phaseId = urem(phaseId, idx_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
+        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
         off_0 = mul(off_0, outVecVal);
         remained = udiv(remained, minVecVal);
         off_0 = add(off_0, mul(remained, minVecVal));
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 01738db1c..17fca9f6e 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -30,18 +30,32 @@ def matmul_no_scf_kernel(


 # TODO: num_warps could only be 4 for now
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [128, 256, 32, 4],
-    [256, 128, 16, 4],
-    [128, 16, 32, 4],
-    [32, 128, 64, 4],
-    [128, 128, 64, 4],
-    [64, 128, 128, 4],
-    [64, 128, 128, 2],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [128, 256, 32],
+        [256, 128, 16],
+        [128, 16, 32],
+        [32, 128, 64],
+        [128, 128, 64],
+        [64, 128, 128],
+    ]
+    for num_warps in [2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
@@ -55,16 +69,32 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-    [64, 128, 128, 1],
-    [128, 128, 128, 4],
-    [16, 8, 32, 1],
-    [32, 16, 64, 2],
-    [32, 16, 64, 4],
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [64, 128, 128],
+        [128, 128, 128],
+        [16, 8, 32],
+        [32, 16, 64],
+        [32, 16, 64],
+    ]
+    for num_warps in [1, 2, 4]
+    for trans_a in [False, True]
+    for trans_b in [False, True]
 ])
-def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
-    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    SIZE_M, SIZE_N, SIZE_K = SHAPE
+
+    if (TRANS_A):
+        a = torch.randint(-5, 5, (SIZE_K, SIZE_M), device='cuda', dtype=torch.int8).T
+    else:
+        a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
+
+    if (TRANS_B):
+        b = torch.randint(-5, 5, (SIZE_N, SIZE_K), device='cuda', dtype=torch.int8).T
+    else:
+        b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)

     grid = lambda META: (1, )
@@ -125,28 +155,39 @@ def get_variant_golden(a, b):
     return c_padded[:SIZE_M, :SIZE_N]


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
     # Non-forloop
-    [64, 32, 64, 4, 64, 32, 64],
-    [128, 64, 128, 4, 128, 64, 128],
+    [64, 32, 64, 4, 64, 32, 64, False, False],
+    [128, 64, 128, 4, 128, 64, 128, False, False],
     # K-Forloop
-    [64, 32, 128, 4, 64, 32, 64],
-    [128, 16, 128, 4, 128, 16, 32],
-    [32, 16, 128, 4, 32, 16, 32],
-    [32, 64, 128, 4, 32, 64, 32],
-    [32, 128, 256, 4, 32, 128, 64],
-    [64, 128, 64, 4, 64, 128, 32],
-    [64, 64, 128, 4, 64, 64, 32],
-    [128, 128, 64, 4, 128, 128, 32],
-    [128, 128, 128, 4, 128, 128, 32],
-    [128, 128, 256, 4, 128, 128, 64],
-    [128, 256, 128, 4, 128, 256, 32],
-    [256, 128, 64, 4, 256, 128, 16],
-    [128, 64, 128, 4, 128, 64, 32],
+    [64, 32, 128, 4, 64, 32, 64, False, False],
+    [128, 16, 128, 4, 128, 16, 32, False, False],
+    [32, 16, 128, 4, 32, 16, 32, False, False],
+    [32, 64, 128, 4, 32, 64, 32, False, False],
+    [32, 128, 256, 4, 32, 128, 64, False, False],
+    [64, 128, 64, 4, 64, 128, 32, False, False],
+    [64, 64, 128, 4, 64, 64, 32, False, False],
+    [128, 128, 64, 4, 128, 128, 32, False, False],
+    [128, 128, 128, 4, 128, 128, 32, False, False],
+    [128, 128, 256, 4, 128, 128, 64, False, False],
+    [128, 256, 128, 4, 128, 256, 32, False, False],
+    [256, 128, 64, 4, 256, 128, 16, False, False],
+    [128, 64, 128, 4, 128, 64, 32, False, False],
+    # TODO[goostavz]: fix these cases
+    #[128, 64, 128, 4, 128, 64, 32, True, False],
+    #[128, 64, 128, 4, 128, 64, 32, False, True],
 ])
-def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
-    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
+    if (TRANS_A):
+        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
+    else:
+        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+
+    if (TRANS_B):
+        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
+    else:
+        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
     grid = lambda META: (1, )
     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
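
For reference, the convention the new order-aware helpers implement is: order[0] names the fastest-varying dimension, the shape is permuted into that order, the flat index is delinearized with unit stride on the first permuted dimension, and the resulting coordinates are scattered back through order; linearization is the inverse. Below is a minimal standalone sketch of that round trip, written against std::vector instead of llvm::SmallVector/ArrayRef; the names delinearize, getMultiDim, and getLinear are illustrative stand-ins, not the functions added in the patch.

// Standalone illustration of order-aware delinearize/linearize; this mirrors
// the semantics of the patch's helpers but is not the Triton code itself.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Delinearize a flat index, treating dimension 0 of `shape` as the
// fastest-varying one (the "[0, 1, .., n]" convention noted in the patch).
static std::vector<unsigned> delinearize(unsigned linear,
                                         const std::vector<unsigned> &shape) {
  std::vector<unsigned> multiDim(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    multiDim[i] = linear % shape[i];
    linear /= shape[i];
  }
  return multiDim;
}

// Order-aware delinearize: order[0] names the fastest-varying dimension.
// The shape is permuted into `order`, delinearized, then scattered back.
static std::vector<unsigned> getMultiDim(unsigned linear,
                                         const std::vector<unsigned> &shape,
                                         const std::vector<unsigned> &order) {
  assert(shape.size() == order.size());
  std::vector<unsigned> reorderedShape(shape.size());
  for (std::size_t i = 0; i < order.size(); ++i)
    reorderedShape[i] = shape[order[i]];
  std::vector<unsigned> reorderedMultiDim = delinearize(linear, reorderedShape);
  std::vector<unsigned> multiDim(shape.size());
  for (std::size_t i = 0; i < order.size(); ++i)
    multiDim[order[i]] = reorderedMultiDim[i];
  return multiDim;
}

// Order-aware linearize: the inverse of getMultiDim.
static unsigned getLinear(const std::vector<unsigned> &multiDim,
                          const std::vector<unsigned> &shape,
                          const std::vector<unsigned> &order) {
  assert(multiDim.size() == shape.size() && shape.size() == order.size());
  unsigned linear = 0, stride = 1;
  for (std::size_t i = 0; i < order.size(); ++i) {
    linear += multiDim[order[i]] * stride;
    stride *= shape[order[i]];
  }
  return linear;
}

int main() {
  std::vector<unsigned> shape = {4, 8}; // dim 0 has extent 4, dim 1 has extent 8
  std::vector<unsigned> order = {1, 0}; // dim 1 is the contiguous one (row-major)
  for (unsigned i = 0; i < 4 * 8; ++i)
    assert(getLinear(getMultiDim(i, shape, order), shape, order) == i); // round-trips
  std::vector<unsigned> multiDim = getMultiDim(10, shape, order);
  std::printf("linear 10 -> [%u, %u]\n", multiDim[0], multiDim[1]); // prints [1, 2]
  return 0;
}

With shape {4, 8} and order {1, 0}, flat index 10 maps to [1, 2] (row 1, column 2), i.e. a row-major buffer; a transposed (.T) input in the tests above flips that order for the corresponding operand, which is what the new order plumbing in the layout conversion is meant to handle.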