[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment correctly considered. Some cases: ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { // mask = make_range(128) < n_element } ``` This should get the vec=2 `ld`/`st` instructions. While the following example ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { // mask = make_range(128) < n_element } ``` it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
@@ -50,3 +50,92 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
%pid = tt.get_program_id {axis = 0 : i32} : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1]
|
||||
%1 = arith.muli %pid, %c128_i32 : i32
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128]
|
||||
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1]
|
||||
%4 = arith.addi %3, %2 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1]
|
||||
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128]
|
||||
%9 = tt.splat %n : (i32) -> tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [16]
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1]
|
||||
%cst = arith.constant dense<0.0> : tensor<128xf32>
|
||||
tt.store %5, %cst, %mask : tensor<128xf32>
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
|
||||
func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [16] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
// CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>> )
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
tt.store %15, %13, %mask : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
|
||||
func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c64_i32 : i32
|
||||
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%3 = tt.splat %1 : (i32) -> tensor<64xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<64xi32>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>
|
||||
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
|
||||
// CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
|
||||
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
|
||||
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
|
||||
%13 = arith.addf %11, %12 : tensor<64xf32>
|
||||
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
|
||||
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>
|
||||
tt.store %15, %13, %10 : tensor<64xf32>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user