[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)

End-to-end verification of MMA v1 is not done yet; this is merged in advance to prevent further conflicts.
This commit is contained in:
goostavz
2022-11-28 16:10:30 +08:00
committed by GitHub
parent 9d31998a9d
commit 0c1d4d764e
3 changed files with 116 additions and 32 deletions

View File

@@ -2926,7 +2926,7 @@ private:
return multiDimOffset; return multiDimOffset;
} }
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) { if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(2); SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2); SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc); Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32); Value warpSize = idx_val(32);
@@ -2936,31 +2936,79 @@ private:
SmallVector<Value> multiDimWarpId(2); SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16)); Value _1 = idx_val(1);
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8)); Value _2 = idx_val(2);
Value four = idx_val(4); Value _4 = idx_val(4);
Value mmaGrpId = udiv(laneId, four); Value _8 = idx_val(8);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8)); Value _16 = idx_val(16);
Value mmaThreadIdInGrp = urem(laneId, four); if (mmaLayout.getVersion() == 2) {
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2)); multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1)); multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16)); Value mmaGrpId = udiv(laneId, _4);
mmaColIdx[0] = add(mmaGrpId, colWarpOffset); Value mmaGrpIdP8 = add(mmaGrpId, _8);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset); Value mmaThreadIdInGrp = urem(laneId, _4);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8)); Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset); Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset); Value colWarpOffset = mul(multiDimWarpId[0], _16);
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], _8);
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
} else if (mmaLayout.getVersion() == 1) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value partId = udiv(laneId, _4);
Value partIdDiv4 = udiv(partId, _4);
Value partIdRem4 = urem(partId, _4);
Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
Value partColOffset = mul(urem(partIdRem4, _2), _8);
Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _4);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
assert(rank == 2); assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
SmallVector<Value> multiDimOffset(rank); SmallVector<Value> multiDimOffset(rank);
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1]; if (mmaLayout.getVersion() == 2) {
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1]; multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(multiDimOffset[0], multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
idx_val(multiDimCTAInRepId[0] * shapePerCTA[0])); multiDimOffset[0] = add(
multiDimOffset[1] = add(multiDimOffset[1], multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
idx_val(multiDimCTAInRepId[1] * shapePerCTA[1])); multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.getVersion() == 1) {
// the order of elements in a thread:
// c0, c1, c4, c5
// c2, c3, c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[0];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaColIdx[elemId % 2];
multiDimOffset[1] = mmaRowIdx[1];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[0];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
multiDimOffset[1] = mmaRowIdx[1];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
return multiDimOffset; return multiDimOffset;
} }
llvm_unreachable("unexpected layout in getMultiDimOffset"); llvm_unreachable("unexpected layout in getMultiDimOffset");

View File

@@ -78,9 +78,9 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
} }
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) { if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1) if (mmaLayout.getVersion() == 1)
return SmallVector<unsigned>{4, 8}; return {4, 8};
if (mmaLayout.getVersion() == 2) if (mmaLayout.getVersion() == 2)
return SmallVector<unsigned>{8, 4}; return {8, 4};
} }
assert(0 && "getThreadsPerWarp not implemented"); assert(0 && "getThreadsPerWarp not implemented");
return {}; return {};
@@ -106,9 +106,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) { } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent()); return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) { } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 && if (mmaLayout.getVersion() == 2) {
"mmaLayout version = 1 is not implemented yet"); return {2, 2};
return SmallVector<unsigned>{2, 2}; } else if (mmaLayout.getVersion() == 1) {
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) { } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent(); auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent"); assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -194,6 +198,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not " assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet"); "supported yet");
} }
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.getVersion() == 1) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else { } else {
assert(0 && "Unimplemented usage of getShapePerCTA"); assert(0 && "Unimplemented usage of getShapePerCTA");
} }
@@ -205,9 +219,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(), return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end()); blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) { } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0}; return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) { } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return SmallVector<unsigned>{1, 0}; return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) { } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent()); SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim(); unsigned dim = sliceLayout.getDim();
@@ -358,6 +372,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2; unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2; unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow; res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
} }
return res; return res;

View File

@@ -712,8 +712,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}> #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_block // CHECK-LABEL: convert_layout_mmav2_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) { func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store // CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3> // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store // CHECK: llvm.store
@@ -728,6 +728,26 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// ----- // -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} {