[Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
This commit is contained in:
@@ -22,11 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||
// CHECK-NEXT: scratch offset = 16384, size = 0
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||
|
||||
@@ -52,20 +50,16 @@ func @reusable(%A : !tt.ptr<f16>) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
|
||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 16384, size = 0
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 24576, size = 0
|
||||
// CHECK-NEXT: offset = 16384, size = 8192
|
||||
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
|
||||
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: scratch offset = 8192, size = 0
|
||||
// CHECK-NEXT: offset = 0, size = 8192
|
||||
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
|
||||
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
@@ -293,6 +293,44 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global internal @global_smem
|
||||
// CHECK-LABEL: basic_alloc_tensor
|
||||
func @basic_alloc_tensor() {
|
||||
// CHECK: llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: llvm.mlir.constant
|
||||
// CHECK-NEXT: llvm.getelementptr
|
||||
// CHECK-NEXT: llvm.bitcast
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.mlir.global internal @global_smem
|
||||
// CHECK-LABEL: basic_extract_slice
|
||||
func @basic_extract_slice() {
|
||||
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
|
||||
// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
|
||||
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
|
||||
// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
|
||||
// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
|
||||
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
|
||||
%index = arith.constant 1 : i32
|
||||
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
||||
%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: basic_splat
|
||||
|
Reference in New Issue
Block a user