[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)
The e2e verification of mma v1 is not done yet. Get this merged in advance just to prevent more conflicts.
This commit is contained in:
@@ -2926,7 +2926,7 @@ private:
|
|||||||
return multiDimOffset;
|
return multiDimOffset;
|
||||||
}
|
}
|
||||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
SmallVector<Value> mmaColIdx(2);
|
SmallVector<Value> mmaColIdx(4);
|
||||||
SmallVector<Value> mmaRowIdx(2);
|
SmallVector<Value> mmaRowIdx(2);
|
||||||
Value threadId = getThreadId(rewriter, loc);
|
Value threadId = getThreadId(rewriter, loc);
|
||||||
Value warpSize = idx_val(32);
|
Value warpSize = idx_val(32);
|
||||||
@@ -2936,31 +2936,79 @@ private:
|
|||||||
SmallVector<Value> multiDimWarpId(2);
|
SmallVector<Value> multiDimWarpId(2);
|
||||||
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||||
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||||
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
|
Value _1 = idx_val(1);
|
||||||
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
|
Value _2 = idx_val(2);
|
||||||
Value four = idx_val(4);
|
Value _4 = idx_val(4);
|
||||||
Value mmaGrpId = udiv(laneId, four);
|
Value _8 = idx_val(8);
|
||||||
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
|
Value _16 = idx_val(16);
|
||||||
Value mmaThreadIdInGrp = urem(laneId, four);
|
if (mmaLayout.getVersion() == 2) {
|
||||||
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
|
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
|
||||||
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
|
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
|
||||||
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
|
Value mmaGrpId = udiv(laneId, _4);
|
||||||
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
|
Value mmaGrpIdP8 = add(mmaGrpId, _8);
|
||||||
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
|
Value mmaThreadIdInGrp = urem(laneId, _4);
|
||||||
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
|
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
|
||||||
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
|
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
|
||||||
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
|
Value colWarpOffset = mul(multiDimWarpId[0], _16);
|
||||||
|
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
|
||||||
|
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
|
||||||
|
Value rowWarpOffset = mul(multiDimWarpId[1], _8);
|
||||||
|
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
|
||||||
|
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
|
||||||
|
} else if (mmaLayout.getVersion() == 1) {
|
||||||
|
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
|
||||||
|
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
|
||||||
|
Value partId = udiv(laneId, _4);
|
||||||
|
Value partIdDiv4 = udiv(partId, _4);
|
||||||
|
Value partIdRem4 = urem(partId, _4);
|
||||||
|
Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
|
||||||
|
partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
|
||||||
|
Value partColOffset = mul(urem(partIdRem4, _2), _8);
|
||||||
|
Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
|
||||||
|
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
|
||||||
|
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
|
||||||
|
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||||
|
mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
|
||||||
|
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||||
|
mmaColIdx[2] = add(mmaColIdx[0], _4);
|
||||||
|
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected MMALayout version");
|
||||||
|
}
|
||||||
|
|
||||||
assert(rank == 2);
|
assert(rank == 2);
|
||||||
assert(mmaLayout.getVersion() == 2 &&
|
|
||||||
"mmaLayout ver1 not implemented yet");
|
|
||||||
SmallVector<Value> multiDimOffset(rank);
|
SmallVector<Value> multiDimOffset(rank);
|
||||||
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
|
if (mmaLayout.getVersion() == 2) {
|
||||||
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
|
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
|
||||||
multiDimOffset[0] = add(multiDimOffset[0],
|
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
|
||||||
idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
multiDimOffset[0] = add(
|
||||||
multiDimOffset[1] = add(multiDimOffset[1],
|
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
||||||
idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
multiDimOffset[1] = add(
|
||||||
|
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||||
|
} else if (mmaLayout.getVersion() == 1) {
|
||||||
|
// the order of elements in a thread:
|
||||||
|
// c0, c1, c4, c5
|
||||||
|
// c2, c3, c6, c7
|
||||||
|
if (elemId < 2) {
|
||||||
|
multiDimOffset[0] = mmaColIdx[elemId % 2];
|
||||||
|
multiDimOffset[1] = mmaRowIdx[0];
|
||||||
|
} else if (elemId >= 2 && elemId < 4) {
|
||||||
|
multiDimOffset[0] = mmaColIdx[elemId % 2];
|
||||||
|
multiDimOffset[1] = mmaRowIdx[1];
|
||||||
|
} else if (elemId >= 4 && elemId < 6) {
|
||||||
|
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
|
||||||
|
multiDimOffset[1] = mmaRowIdx[0];
|
||||||
|
} else if (elemId >= 6) {
|
||||||
|
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
|
||||||
|
multiDimOffset[1] = mmaRowIdx[1];
|
||||||
|
}
|
||||||
|
multiDimOffset[0] = add(
|
||||||
|
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
||||||
|
multiDimOffset[1] = add(
|
||||||
|
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected MMALayout version");
|
||||||
|
}
|
||||||
return multiDimOffset;
|
return multiDimOffset;
|
||||||
}
|
}
|
||||||
llvm_unreachable("unexpected layout in getMultiDimOffset");
|
llvm_unreachable("unexpected layout in getMultiDimOffset");
|
||||||
|
@@ -78,9 +78,9 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
|||||||
}
|
}
|
||||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
if (mmaLayout.getVersion() == 1)
|
if (mmaLayout.getVersion() == 1)
|
||||||
return SmallVector<unsigned>{4, 8};
|
return {4, 8};
|
||||||
if (mmaLayout.getVersion() == 2)
|
if (mmaLayout.getVersion() == 2)
|
||||||
return SmallVector<unsigned>{8, 4};
|
return {8, 4};
|
||||||
}
|
}
|
||||||
assert(0 && "getThreadsPerWarp not implemented");
|
assert(0 && "getThreadsPerWarp not implemented");
|
||||||
return {};
|
return {};
|
||||||
@@ -106,9 +106,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
|||||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||||
return getSizePerThread(sliceLayout.getParent());
|
return getSizePerThread(sliceLayout.getParent());
|
||||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
assert(mmaLayout.getVersion() == 2 &&
|
if (mmaLayout.getVersion() == 2) {
|
||||||
"mmaLayout version = 1 is not implemented yet");
|
return {2, 2};
|
||||||
return SmallVector<unsigned>{2, 2};
|
} else if (mmaLayout.getVersion() == 1) {
|
||||||
|
return {2, 4};
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected mma version");
|
||||||
|
}
|
||||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||||
auto parentLayout = dotLayout.getParent();
|
auto parentLayout = dotLayout.getParent();
|
||||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||||
@@ -194,6 +198,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
|||||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||||
"supported yet");
|
"supported yet");
|
||||||
}
|
}
|
||||||
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
|
if (mmaLayout.getVersion() == 2) {
|
||||||
|
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||||
|
8 * mmaLayout.getWarpsPerCTA()[1]};
|
||||||
|
} else if (mmaLayout.getVersion() == 1) {
|
||||||
|
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||||
|
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected mma version");
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||||
}
|
}
|
||||||
@@ -205,9 +219,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
|
|||||||
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
|
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
|
||||||
blockedLayout.getOrder().end());
|
blockedLayout.getOrder().end());
|
||||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||||
return SmallVector<unsigned>{1, 0};
|
return {1, 0};
|
||||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||||
return SmallVector<unsigned>{1, 0};
|
return {1, 0};
|
||||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||||
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
|
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
|
||||||
unsigned dim = sliceLayout.getDim();
|
unsigned dim = sliceLayout.getDim();
|
||||||
@@ -358,6 +372,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
|||||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||||
res = elemsCol * elemsRow;
|
res = elemsCol * elemsRow;
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected mma version");
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
@@ -712,8 +712,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|||||||
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
|
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||||
// CHECK-LABEL: convert_layout_mma_block
|
// CHECK-LABEL: convert_layout_mmav2_block
|
||||||
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||||
// CHECK: llvm.store
|
// CHECK: llvm.store
|
||||||
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
|
||||||
// CHECK: llvm.store
|
// CHECK: llvm.store
|
||||||
@@ -728,6 +728,26 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||||
|
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
|
||||||
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||||
|
// CHECK-LABEL: convert_layout_mmav1_block
|
||||||
|
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
|
||||||
|
// CHECK: llvm.store
|
||||||
|
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||||
|
// CHECK: llvm.store
|
||||||
|
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||||
|
// CHECK: nvvm.barrier0
|
||||||
|
// CHECK: llvm.load
|
||||||
|
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
|
||||||
|
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
Reference in New Issue
Block a user