[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
This commit is contained in:
@@ -293,6 +293,44 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global internal @global_smem
|
||||
// CHECK-LABEL: basic_alloc_tensor
|
||||
func @basic_alloc_tensor() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: llvm.mlir.constant
|
||||
// CHECK-NEXT: llvm.getelementptr
|
||||
// CHECK-NEXT: llvm.bitcast
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global internal @global_smem
|
||||
// CHECK-LABEL: basic_extract_slice
|
||||
func @basic_extract_slice() {
|
||||
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
|
||||
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
|
||||
// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
|
||||
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
|
||||
%index = arith.constant 1 : i32
|
||||
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
||||
%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: basic_splat
|
||||
|
Reference in New Issue
Block a user