[Triton-MLIR] Two fixes on allocation and backend related to MMA v1 (#930)
@@ -31,6 +31,8 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
 
 SmallVector<unsigned> getSizePerThread(Attribute layout);
 
+SmallVector<unsigned> getContigPerThread(Attribute layout);
+
 SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
 
 SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
@@ -13,6 +13,7 @@
 
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
@@ -60,8 +61,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   assert(srcLayout && dstLayout &&
          "Unexpect layout in getScratchConfigForCvtLayout()");
   auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
-  unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
-  unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
+  unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
+  unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -2901,12 +2901,12 @@ private:
       Value mmaThreadIdInGrp = urem(laneId, _4);
       Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
       Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
-      Value colWarpOffset = mul(multiDimWarpId[0], _16);
-      mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
-      mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
-      Value rowWarpOffset = mul(multiDimWarpId[1], _8);
-      mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
-      mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+      Value rowWarpOffset = mul(multiDimWarpId[0], _16);
+      mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
+      mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
+      Value colWarpOffset = mul(multiDimWarpId[1], _8);
+      mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
+      mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
     } else if (mmaLayout.getVersion() == 1) {
       multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
       multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
@@ -2920,7 +2920,7 @@ private:
       Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
       mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
       mmaRowIdx[1] = add(mmaRowIdx[0], _2);
-      mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
+      mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset);
      mmaColIdx[1] = add(mmaColIdx[0], _1);
      mmaColIdx[2] = add(mmaColIdx[0], _4);
      mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
@@ -2931,28 +2931,28 @@ private:
     assert(rank == 2);
     SmallVector<Value> multiDimOffset(rank);
     if (mmaLayout.getVersion() == 2) {
-      multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
-      multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
       multiDimOffset[0] = add(
           multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
       multiDimOffset[1] = add(
           multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
     } else if (mmaLayout.getVersion() == 1) {
       // the order of elements in a thread:
-      // c0, c1, c4, c5
-      // c2, c3, c6, c7
+      // c0, c1, ... c4, c5
+      // c2, c3, ... c6, c7
       if (elemId < 2) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2];
-        multiDimOffset[1] = mmaRowIdx[0];
+        multiDimOffset[0] = mmaRowIdx[0];
+        multiDimOffset[1] = mmaColIdx[elemId % 2];
       } else if (elemId >= 2 && elemId < 4) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2];
-        multiDimOffset[1] = mmaRowIdx[1];
+        multiDimOffset[0] = mmaRowIdx[1];
+        multiDimOffset[1] = mmaColIdx[elemId % 2];
       } else if (elemId >= 4 && elemId < 6) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
-        multiDimOffset[1] = mmaRowIdx[0];
+        multiDimOffset[0] = mmaRowIdx[0];
+        multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
       } else if (elemId >= 6) {
-        multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
-        multiDimOffset[1] = mmaRowIdx[1];
+        multiDimOffset[0] = mmaRowIdx[1];
+        multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
       }
       multiDimOffset[0] = add(
           multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
@@ -3051,6 +3051,7 @@ void ConvertLayoutOpConversion::processReplica(
                          multiDimCTAInRepId, shapePerCTA);
       Value offset =
           linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
+
       auto elemPtrTy = ptr_ty(llvmElemTy, 3);
       Value ptr = gep(elemPtrTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
@@ -3171,6 +3172,11 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
+
+  // TODO[Keren]: A temporary workaround for an issue from membar pass.
+  // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
+  barrier();
+
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
@@ -109,6 +109,8 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
     if (mmaLayout.getVersion() == 2) {
       return {2, 2};
     } else if (mmaLayout.getVersion() == 1) {
+      // Note: here the definition of sizePerThread is obscure, which doesn't
+      // mean vecSize=4 can be supported in the last dimension.
       return {2, 4};
     } else {
       llvm_unreachable("Unexpected mma version");
@@ -140,6 +142,15 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   }
 }
 
+SmallVector<unsigned> getContigPerThread(Attribute layout) {
+  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
+    return {1, 2};
+  } else {
+    return getSizePerThread(layout);
+  }
+}
+
 SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
   SmallVector<unsigned> threads;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -735,9 +735,13 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_layout_mmav1_block
   func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
     // CHECK: llvm.store
-    // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
     // CHECK: llvm.store
-    // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
     // CHECK: nvvm.barrier0
     // CHECK: llvm.load
     // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>