[Triton-MLIR][Backend] Support ConvertLayout blocked->shared and a few fixes related to mma (#716)
@@ -299,6 +299,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_alloc_tensor
  func @basic_alloc_tensor() {
    // CHECK: llvm.mlir.addressof @global_smem
    // CHECK-NEXT: llvm.bitcast
    // CHECK-NEXT: llvm.mlir.constant
    // CHECK-NEXT: llvm.getelementptr
    // CHECK-NEXT: llvm.bitcast
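The checks above spell out the lowering of triton_gpu.alloc_tensor: the "allocated" buffer is not a fresh allocation but a pointer carved out of the module-level shared-memory global. A rough sketch of the LLVM-dialect IR these five checks would match, with the offset value, pointer types, and global size chosen for illustration rather than taken from the actual lowering:

  // Illustrative types and offset only; the test checks op names, not operands.
  %base = llvm.mlir.addressof @global_smem : !llvm.ptr<array<16384 x i8>, 3>
  %raw  = llvm.bitcast %base : !llvm.ptr<array<16384 x i8>, 3> to !llvm.ptr<i8, 3>
  %off  = llvm.mlir.constant(0 : i32) : i32
  %gep  = llvm.getelementptr %raw[%off] : (!llvm.ptr<i8, 3>, i32) -> !llvm.ptr<i8, 3>
  %buf  = llvm.bitcast %gep : !llvm.ptr<i8, 3> to !llvm.ptr<f32, 3>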
@@ -315,13 +316,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_extract_slice
  func @basic_extract_slice() {
    // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
    // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
    // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
    // CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
    // CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
    // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
    // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
    // CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
    %index = arith.constant 1 : i32
    %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
    %1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
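The mul/getelementptr checks pin down how extract_slice turns the index into a pointer offset: the index (a constant 1 here, captured as %[[OFFSET0]]) is scaled by a second constant %[[OFFSET2]], and the product %[[OFFSET3]] feeds the last two getelementptrs. Assuming each 16x32 slice is stored contiguously in shared memory, that scaling constant would be the per-slice element count, so the arithmetic for this test would be roughly:

  slice elements = 16 * 32 = 512
  element offset = index * 512 = 1 * 512 = 512
  byte offset    = 512 * 4 bytes (f32) = 2048

This is inference from the CHECK pattern; the test itself asserts only the structure of the IR, not these concrete values.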
@@ -515,3 +517,20 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
    return
  }
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
  // CHECK-LABEL: convert_layout_blocked_shared
  func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
    // CHECK: llvm.store
    // CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
    // CHECK: llvm.store
    // CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
    return
  }
}
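The constants in this new blocked->shared test follow from the layouts involved; worked out below as a sanity check (a reading of the attributes, not something the test asserts beyond its CHECK lines), the 128x32 f32 tensor exactly fills the 16 KiB shared-memory global, and the 8-element per-thread contiguity of #blocked0 (sizePerThread = [1, 8]) matching vec = 8 in #shared0 is presumably what lets the stores be vectorized as vector<8xf32>:

  shared memory = 128 * 32 elements * 4 bytes (f32) = 16384 bytes  -> !llvm.array<16384 x i8>
  store width   = 8 contiguous f32 per thread                      -> !llvm.ptr<vector<8xf32>, 3>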