[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)

Correct the Load/Store Op's vector size with the mask's alignment
correctly considered.

Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
    // mask = make_range(128) < n_element
}
```
This should get the vec=2 `ld`/`st` instructions.

While the following example

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
    // mask = make_range(128) < n_element
}
```
it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
Yan Chunwei
2022-10-18 11:43:50 +08:00
committed by GitHub
parent 38a80664b5
commit 4464646efb
6 changed files with 383 additions and 227 deletions

View File

@@ -188,7 +188,7 @@ def test_vecadd_no_scf(num_warps, block_size, shape):
[2, 256, (3, 256 + 7)],
[4, 256, (3, 256 + 7)],
])
def test_vecadd__no_scf_masked(num_warps, block_size, shape):
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
vecadd_no_scf_tester(num_warps, block_size, shape)