[Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)

The e2e verification of mma v1 is not done yet. 
Get this merged in advance just to prevent more conflicts.
This commit is contained in:
goostavz
2022-11-28 16:10:30 +08:00
committed by GitHub
parent 9d31998a9d
commit 0c1d4d764e
3 changed files with 116 additions and 32 deletions

View File

@@ -712,8 +712,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK-LABEL: convert_layout_mmav2_block
func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -728,6 +728,26 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
func @convert_layout_mmav1_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {