diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 5985536eb..c4f169935 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -31,6 +31,8 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
 
 SmallVector<unsigned> getSizePerThread(Attribute layout);
 
+SmallVector<unsigned> getContigPerThread(Attribute layout);
+
 SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
 
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 068956697..a2262e75f 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -13,6 +13,7 @@
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
@@ -60,8 +61,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   assert(srcLayout && dstLayout &&
          "Unexpect layout in getScratchConfigForCvtLayout()");
   auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
-  unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
-  unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
+  unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
+  unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
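
Why the Allocation.cpp change matters: the scratch-buffer configuration picks
the vectorization width for the shared-memory round-trip, and for mma layouts
the elements a thread owns are not all adjacent, so sizePerThread overstates
the safe width. Below is a minimal standalone illustration of the distinction
(plain C++, not the Triton API; the mma<v1> constants are taken from the
Dialect.cpp hunk further down):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // mma<v1> accumulator: each thread owns 8 values arranged as
      //   c0 c1 .. c4 c5
      //   c2 c3 .. c6 c7
      // sizePerThread along the fast dimension is 4, but only pairs such as
      // (c0, c1) and (c4, c5) are actually adjacent in memory.
      unsigned sizePerThread = 4;   // elements per thread, fast dimension
      unsigned contigPerThread = 2; // longest contiguous run, fast dimension
      // The safe vector width is bounded by contiguity, not element count.
      unsigned vec = std::min(sizePerThread, contigPerThread);
      printf("store width: %u elements\n", vec); // prints 2, not 4
      return 0;
    }
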
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 138e0bd2d..f7b5e5c85 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2901,12 +2901,12 @@ private:
       Value mmaThreadIdInGrp = urem(laneId, _4);
       Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
       Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
-      Value colWarpOffset = mul(multiDimWarpId[0], _16);
-      mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
-      mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
-      Value rowWarpOffset = mul(multiDimWarpId[1], _8);
-      mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
-      mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+      Value rowWarpOffset = mul(multiDimWarpId[0], _16);
+      mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
+      mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
+      Value colWarpOffset = mul(multiDimWarpId[1], _8);
+      mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
+      mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
     } else if (mmaLayout.getVersion() == 1) {
       multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
       multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
@@ -2920,7 +2920,7 @@ private:
       Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
       mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
       mmaRowIdx[1] = add(mmaRowIdx[0], _2);
-      mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
+      mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset);
       mmaColIdx[1] = add(mmaColIdx[0], _1);
       mmaColIdx[2] = add(mmaColIdx[0], _4);
       mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
@@ -2931,28 +2931,28 @@ private:
     assert(rank == 2);
     SmallVector<Value> multiDimOffset(rank);
     if (mmaLayout.getVersion() == 2) {
-      multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
-      multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
       multiDimOffset[0] = add(
           multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
       multiDimOffset[1] = add(
          multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
     } else if (mmaLayout.getVersion() == 1) {
       // the order of elements in a thread:
-      // c0, c1, c4, c5
-      // c2, c3, c6, c7
+      // c0, c1, ... c4, c5
+      // c2, c3, ... c6, c7
       if (elemId < 2) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2];
-        multiDimOffset[1] = mmaRowIdx[0];
+        multiDimOffset[0] = mmaRowIdx[0];
+        multiDimOffset[1] = mmaColIdx[elemId % 2];
       } else if (elemId >= 2 && elemId < 4) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2];
-        multiDimOffset[1] = mmaRowIdx[1];
+        multiDimOffset[0] = mmaRowIdx[1];
+        multiDimOffset[1] = mmaColIdx[elemId % 2];
       } else if (elemId >= 4 && elemId < 6) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
-        multiDimOffset[1] = mmaRowIdx[0];
+        multiDimOffset[0] = mmaRowIdx[0];
+        multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
       } else if (elemId >= 6) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
-        multiDimOffset[1] = mmaRowIdx[1];
+        multiDimOffset[0] = mmaRowIdx[1];
+        multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
       }
       multiDimOffset[0] = add(
           multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
@@ -3051,6 +3051,7 @@ void ConvertLayoutOpConversion::processReplica(
                              multiDimCTAInRepId, shapePerCTA);
         Value offset = linearize(rewriter, loc, multiDimOffset, paddedRepShape,
                                  outOrd);
+        auto elemPtrTy = ptr_ty(llvmElemTy, 3);
         Value ptr = gep(elemPtrTy, smemBase, offset);
         auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3171,6 +3172,11 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
+
+  // TODO[Keren]: A temporary workaround for an issue from the membar pass.
+  // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
+  barrier();
+
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
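
The first hunk above is a rename that makes the code agree with the hardware
mapping of an mma<v2> (m16n8) accumulator: laneId / 4 selects the row within a
group and laneId % 4 selects the column pair, so the group index contributes
to the row, not the column. A standalone C++ sketch of the corrected mapping
(mmaGrpId = udiv(laneId, _4) is defined above the hunk and assumed here; warp
offsets are zeroed for brevity):

    #include <cstdio>

    int main() {
      unsigned warpM = 0, warpN = 0; // multiDimWarpId[0], multiDimWarpId[1]
      for (unsigned laneId = 0; laneId < 8; ++laneId) {
        unsigned mmaGrpId = laneId / 4;
        unsigned row0 = mmaGrpId + warpM * 16;        // mmaRowIdx[0]
        unsigned row1 = mmaGrpId + 8 + warpM * 16;    // mmaRowIdx[1]
        unsigned col0 = (laneId % 4) * 2 + warpN * 8; // mmaColIdx[0]
        unsigned col1 = col0 + 1;                     // mmaColIdx[1]
        // Per-thread element order: c0=(row0,col0), c1=(row0,col1),
        // c2=(row1,col0), c3=(row1,col1), matching the elemId logic above.
        printf("lane %2u: rows {%u, %u}, cols {%u, %u}\n", laneId, row0, row1,
               col0, col1);
      }
      return 0;
    }
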
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 68c7f48a2..a178470d9 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -109,6 +109,8 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
     if (mmaLayout.getVersion() == 2) {
       return {2, 2};
     } else if (mmaLayout.getVersion() == 1) {
+      // Note: the definition of sizePerThread here is obscure; it does not
+      // mean that vecSize=4 is supported in the last dimension.
      return {2, 4};
    } else {
      llvm_unreachable("Unexpected mma version");
@@ -140,6 +142,15 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }
 
+SmallVector<unsigned> getContigPerThread(Attribute layout) {
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
+    return {1, 2};
+  } else {
+    return getSizePerThread(layout);
+  }
+}
+
 SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
   SmallVector<unsigned> threads;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index dc4acb183..e7f9a4b90 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -735,9 +735,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav1_block
 func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
   // CHECK: llvm.store
-  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+  // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
   // CHECK: llvm.store
-  // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+  // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+  // CHECK: llvm.store
+  // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+  // CHECK: llvm.store
+  // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load
   // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
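
The updated FileCheck expectations follow directly from the layout constants:
an mma<v1> thread still contributes 2 * 4 = 8 f32 values per replica, but the
contiguity bound of 2 splits them into four 2-wide stores where the old
expectations were two 4-wide ones. A quick arithmetic sketch (plain C++;
constants taken from the hunks above):

    #include <cstdio>

    int main() {
      unsigned sizePerThread[2] = {2, 4}; // mma<v1> getSizePerThread
      unsigned contig = 2;                // getContigPerThread()[1]
      unsigned elems = sizePerThread[0] * sizePerThread[1]; // 8 per thread
      unsigned numStores = elems / contig;                  // 4
      printf("%u stores of vector<%uxf32> each\n", numStores, contig);
      return 0;
    }
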