From 0c1d4d764edb4986dc675dc3c4e1890e06a7fc73 Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Mon, 28 Nov 2022 16:10:30 +0800
Subject: [PATCH] [Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)

End-to-end verification of mma v1 is not done yet; this is merged in advance
to prevent further conflicts.
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 94 ++++++++++++++-----
 lib/Dialect/TritonGPU/IR/Dialect.cpp | 30 ++++--
 test/Conversion/tritongpu_to_llvm.mlir | 24 ++++-
 3 files changed, 116 insertions(+), 32 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 322c43446..921895dc9 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2926,7 +2926,7 @@ private:
       return multiDimOffset;
     }
     if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      SmallVector<Value> mmaColIdx(2);
+      SmallVector<Value> mmaColIdx(4);
       SmallVector<Value> mmaRowIdx(2);
       Value threadId = getThreadId(rewriter, loc);
       Value warpSize = idx_val(32);
@@ -2936,31 +2936,79 @@ private:
       SmallVector<Value> multiDimWarpId(2);
       multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
       multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
-      multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
-      multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
-      Value four = idx_val(4);
-      Value mmaGrpId = udiv(laneId, four);
-      Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
-      Value mmaThreadIdInGrp = urem(laneId, four);
-      Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
-      Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
-      Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
-      mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
-      mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
-      Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
-      mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
-      mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+      Value _1 = idx_val(1);
+      Value _2 = idx_val(2);
+      Value _4 = idx_val(4);
+      Value _8 = idx_val(8);
+      Value _16 = idx_val(16);
+      if (mmaLayout.getVersion() == 2) {
+        multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+        multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
+        Value mmaGrpId = udiv(laneId, _4);
+        Value mmaGrpIdP8 = add(mmaGrpId, _8);
+        Value mmaThreadIdInGrp = urem(laneId, _4);
+        Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
+        Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
+        Value colWarpOffset = mul(multiDimWarpId[0], _16);
+        mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
+        mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
+        Value rowWarpOffset = mul(multiDimWarpId[1], _8);
+        mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
+        mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+      } else if (mmaLayout.getVersion() == 1) {
+        multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+        multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
+        Value partId = udiv(laneId, _4);
+        Value partIdDiv4 = udiv(partId, _4);
+        Value partIdRem4 = urem(partId, _4);
+        Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
+        partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
+        Value partColOffset = mul(urem(partIdRem4, _2), _8);
+        Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
+        Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
+        mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
+        mmaRowIdx[1] = add(mmaRowIdx[0], _2);
+        mmaColIdx[0] = add(udiv(urem(laneId, _4), _2), colOffset);
+        mmaColIdx[1] = add(mmaColIdx[0], _1);
+        mmaColIdx[2] = add(mmaColIdx[0], _4);
+        mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
+      } else {
+        llvm_unreachable("Unexpected MMALayout version");
+      }
       assert(rank == 2);
-      assert(mmaLayout.getVersion() == 2 &&
-             "mmaLayout ver1 not implemented yet");
       SmallVector<Value> multiDimOffset(rank);
-      multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
-      multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
-      multiDimOffset[0] = add(multiDimOffset[0],
-                              idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
-      multiDimOffset[1] = add(multiDimOffset[1],
-                              idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      if (mmaLayout.getVersion() == 2) {
+        multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
+        multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+        multiDimOffset[0] = add(
+            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
+        multiDimOffset[1] = add(
+            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      } else if (mmaLayout.getVersion() == 1) {
+        // the order of elements in a thread:
+        //   c0, c1, c4, c5
+        //   c2, c3, c6, c7
+        if (elemId < 2) {
+          multiDimOffset[0] = mmaColIdx[elemId % 2];
+          multiDimOffset[1] = mmaRowIdx[0];
+        } else if (elemId >= 2 && elemId < 4) {
+          multiDimOffset[0] = mmaColIdx[elemId % 2];
+          multiDimOffset[1] = mmaRowIdx[1];
+        } else if (elemId >= 4 && elemId < 6) {
+          multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
+          multiDimOffset[1] = mmaRowIdx[0];
+        } else if (elemId >= 6) {
+          multiDimOffset[0] = mmaColIdx[elemId % 2 + 2];
+          multiDimOffset[1] = mmaRowIdx[1];
+        }
+        multiDimOffset[0] = add(
+            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
+        multiDimOffset[1] = add(
+            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      } else {
+        llvm_unreachable("Unexpected MMALayout version");
+      }
       return multiDimOffset;
     }
     llvm_unreachable("unexpected layout in getMultiDimOffset");
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d603e823c..68c7f48a2 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -78,9 +78,9 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
   }
   if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.getVersion() == 1)
-      return SmallVector<unsigned>{4, 8};
+      return {4, 8};
     if (mmaLayout.getVersion() == 2)
-      return SmallVector<unsigned>{8, 4};
+      return {8, 4};
   }
   assert(0 && "getThreadsPerWarp not implemented");
   return {};
@@ -106,9 +106,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     return getSizePerThread(sliceLayout.getParent());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    assert(mmaLayout.getVersion() == 2 &&
-           "mmaLayout version = 1 is not implemented yet");
-    return SmallVector<unsigned>{2, 2};
+    if (mmaLayout.getVersion() == 2) {
+      return {2, 2};
+    } else if (mmaLayout.getVersion() == 1) {
+      return {2, 4};
+    } else {
+      llvm_unreachable("Unexpected mma version");
+    }
   } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
     auto parentLayout = dotLayout.getParent();
     assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -194,6 +198,16 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
       assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
                   "supported yet");
     }
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    if (mmaLayout.getVersion() == 2) {
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              8 * mmaLayout.getWarpsPerCTA()[1]};
+    } else if (mmaLayout.getVersion() == 1) {
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              16 * mmaLayout.getWarpsPerCTA()[1]};
+    } else {
+      llvm_unreachable("Unexpected mma version");
+    }
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
   }
@@ -205,9 +219,9 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
                                  blockedLayout.getOrder().end());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    return SmallVector<unsigned>{1, 0};
+    return {1, 0};
   } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
-    return SmallVector<unsigned>{1, 0};
+    return {1, 0};
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
     unsigned dim = sliceLayout.getDim();
@@ -358,6 +372,8 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
     unsigned elemsCol = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
     unsigned elemsRow = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
     res = elemsCol * elemsRow;
+  } else {
+    llvm_unreachable("Unexpected mma version");
  }

   return res;
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index a7fb4551a..ac632c13b 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -712,8 +712,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
-  // CHECK-LABEL: convert_layout_mma_block
-  func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
+  // CHECK-LABEL: convert_layout_mmav2_block
+  func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
     // CHECK: llvm.store
     // CHECK-SAME: !llvm.ptr, 3>
     // CHECK: llvm.store
@@ -728,6 +728,26 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 
 // -----
 
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
+  // CHECK-LABEL: convert_layout_mmav1_block
+  func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr, 3>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr, 3>
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
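
A quick way to sanity-check the new mma v1 branch of getMultiDimOffset is to replay its index arithmetic with plain integers. The standalone C++ sketch below (illustrative only, not part of the patch) mirrors that computation, assuming a single warp at the CTA origin (multiDimWarpId = {0, 0}); the variable names follow the patch and are not Triton APIs.

    #include <cstdio>

    int main() {
      const int warpId0 = 0, warpId1 = 0; // multiDimWarpId[0] / [1], assumed 0
      for (int laneId = 0; laneId < 32; ++laneId) {
        int partId = laneId / 4;                 // udiv(laneId, _4)
        int partIdDiv4 = partId / 4;
        int partIdRem4 = partId % 4;
        int partRowOffset = partIdDiv4 * 4 + (partIdRem4 / 2) * 8;
        int partColOffset = (partIdRem4 % 2) * 8;
        int colOffset = warpId0 * 16 + partColOffset;
        int rowOffset = warpId1 * 16 + partRowOffset;
        int row0 = laneId % 2 + rowOffset;       // mmaRowIdx[0]
        int col0 = (laneId % 4) / 2 + colOffset; // mmaColIdx[0]
        printf("lane %2d -> rows {%2d, %2d}, cols {%2d, %2d, %2d, %2d}\n",
               laneId, row0, row0 + 2, col0, col0 + 1, col0 + 4, col0 + 5);
      }
      return 0;
    }

Printing the table shows each lane covering two mmaRowIdx values and four mmaColIdx values, i.e. eight elements per thread, consistent with the eight elemId cases handled in the v1 branch.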
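The elemId-to-offset mapping for v1 can be checked the same way. This sketch (again illustrative, not part of the patch) replays the if/else chain realizing the per-thread element order c0, c1, c4, c5 / c2, c3, c6, c7 with plain integers; rowIdx/colIdx stand in for one thread's mmaRowIdx/mmaColIdx values, and the CTA-repetition offsets are omitted.

    #include <cstdio>

    int main() {
      // Example base indices for one thread (lane 0 in the previous sketch).
      int rowIdx[2] = {0, 2};
      int colIdx[4] = {0, 1, 4, 5};
      for (int elemId = 0; elemId < 8; ++elemId) {
        int dim0, dim1;
        if (elemId < 2) {            // c0, c1
          dim0 = colIdx[elemId % 2];
          dim1 = rowIdx[0];
        } else if (elemId < 4) {     // c2, c3
          dim0 = colIdx[elemId % 2];
          dim1 = rowIdx[1];
        } else if (elemId < 6) {     // c4, c5
          dim0 = colIdx[elemId % 2 + 2];
          dim1 = rowIdx[0];
        } else {                     // c6, c7
          dim0 = colIdx[elemId % 2 + 2];
          dim1 = rowIdx[1];
        }
        printf("c%d -> (dim0 = %d, dim1 = %d)\n", elemId, dim0, dim1);
      }
      return 0;
    }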
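Finally, the new per-version constants in Dialect.cpp can be spot-checked against the added lit test. The sketch below (illustrative, not part of the patch) evaluates the getShapePerCTA / getSizePerThread formulas by hand for the test's configuration (mma v1, warpsPerCTA = [2, 1]).

    #include <cstdio>

    int main() {
      const int warpsPerCTA[2] = {2, 1};
      const int version = 1; // matches the new convert_layout_mmav1_blocked test
      int shapePerCTA[2], sizePerThread[2];
      if (version == 2) {
        shapePerCTA[0] = 16 * warpsPerCTA[0];
        shapePerCTA[1] = 8 * warpsPerCTA[1];
        sizePerThread[0] = 2;
        sizePerThread[1] = 2;
      } else { // version == 1
        shapePerCTA[0] = 16 * warpsPerCTA[0];
        shapePerCTA[1] = 16 * warpsPerCTA[1];
        sizePerThread[0] = 2;
        sizePerThread[1] = 4;
      }
      printf("shapePerCTA = {%d, %d}, sizePerThread = {%d, %d}\n",
             shapePerCTA[0], shapePerCTA[1], sizePerThread[0], sizePerThread[1]);
      return 0;
    }

For the 32x16 tensor in the test this prints shapePerCTA = {32, 16}, so the whole tensor is exactly one CTA tile of the v1 layout.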