[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)

## Features

- Allow taking a block-shaped slice of a tensor, as long as each dimension
of the slice is contiguous (unit stride); see the sketch after this list.
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.
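
A minimal sketch of the new slicing behavior, adapted from the lit test updated in this commit (`#shared0` is a shared-layout encoding defined elsewhere in that test):

```mlir
// Take the 16x32 block at index 1 along axis 0.
// Offsets [%index, 0, 0], sizes [1, 16, 32], strides [1, 1, 1]:
// every dimension is contiguous (unit stride), so the slice is allowed.
%index = arith.constant 1 : index
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1]
    : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
```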

## Known Limitations

- `insert_slice_async` still uses the old semantics. A follow-up PR may
add semantics similar to `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed; see the sketch after this list.
- May cause a slight performance slowdown, since strides are passed as
values rather than constants (e.g., `int`). Passing strides as attributes
would be difficult in the presence of control flow; a block argument makes
it possible to accept tensors with different strides.
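
As a rough illustration of the stride restriction (hypothetical IR, not taken from this commit's tests), only all-unit-stride slices are expected to be accepted:

```mlir
// Allowed: strides [1, 1, 1], the extracted block is contiguous.
%ok = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1]
    : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
// Not allowed: stride 2 along the second dimension is a strided
// access, which this change does not support.
%bad = tensor.extract_slice %0[%index, 0, 0][1, 8, 32][1, 2, 1]
    : tensor<128x16x32xf32, #shared0> to tensor<8x32xf32, #shared0>
```
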
Keren Zhou committed on 2022-11-06 22:59:03 -08:00 (committed by GitHub)
parent a4ff0c362c, commit fdd59900f7
26 changed files with 507 additions and 339 deletions


@@ -346,18 +346,24 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: basic_extract_slice
 func @basic_extract_slice() {
-// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
-// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
-// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
-// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
-// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
-// CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
-// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
-// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
-// CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
-%index = arith.constant 1 : i32
+// CHECK: llvm.mlir.addressof @global_smem
+// CHECK: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.extractvalue
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.mul
+// CHECK-NEXT: llvm.add
+// CHECK-NEXT: llvm.getelementptr
+%index = arith.constant 1 : index
 %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
-%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
+%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
 return
 }
 }
@@ -488,22 +494,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
 %index = arith.constant 1 : i32
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(0 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK: llvm.mlir.constant(16 : i32) : i32
+// CHECK: llvm.add
 // CHECK: llvm.inline_asm
-// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
+// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
 // CHECK: llvm.inline_asm
 // CHECK-SAME: cp.async.commit_group
 %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>