[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features

- Allow taking a block-shaped slice of a tensor, as long as each dimension is contiguous (unit stride); see the sketch below.
- Fix some problems in the semantics of `insert_slice_async`.
- More general verification for ops that return a shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantics. A follow-up PR may later give it semantics similar to `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed.
- May cause a slight performance slowdown, since strides are now passed as values rather than constants (e.g., `int`). Passing strides as attributes would be difficult in the presence of control flow, because a block argument may receive tensors with different strides.
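As a minimal sketch of the change (adapted from the lit test updated in this diff; the `#shared0` layout, the concrete shapes, and the SSA value names are just illustrative), a slice that was previously taken with `triton_gpu.extract_slice` and an `axis` attribute is now written with `tensor.extract_slice`, passing explicit offsets, sizes, and unit strides:

```mlir
// Before: slice along a fixed axis; the offset is an i32 value.
%index_old = arith.constant 1 : i32
%src = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%dst_old = triton_gpu.extract_slice %src, %index_old {axis = 0 : i32}
  : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>

// After: explicit [offsets][sizes][strides]; every stride must be 1 (unit stride).
%index = arith.constant 1 : index
%dst = tensor.extract_slice %src[%index, 0, 0][1, 16, 32][1, 1, 1]
  : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
```

Because the offset is an SSA value of type `index` rather than an attribute, the same slice can be taken under control flow; this is also why the lowering computes offsets and strides as values, which is the source of the performance caveat noted above.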
@@ -346,18 +346,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: basic_extract_slice
 func @basic_extract_slice() {
-// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
-// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
-// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
-// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
-// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
-// CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
-// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
-// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
-// CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
-%index = arith.constant 1 : i32
+// CHECK: llvm.mlir.addressof @global_smem
+// CHECK: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.getelementptr
+%index = arith.constant 1 : index
 %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
-%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
+%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
 return
 }
 }
@@ -488,22 +494,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
 %index = arith.constant 1 : i32
 
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.commit_group
 %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>