[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment correctly considered. Some cases: ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { // mask = make_range(128) < n_element } ``` This should get the vec=2 `ld`/`st` instructions. While the following example ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { // mask = make_range(128) < n_element } ``` it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
@@ -188,7 +188,7 @@ def test_vecadd_no_scf(num_warps, block_size, shape):
|
||||
[2, 256, (3, 256 + 7)],
|
||||
[4, 256, (3, 256 + 7)],
|
||||
])
|
||||
def test_vecadd__no_scf_masked(num_warps, block_size, shape):
|
||||
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
|
||||
vecadd_no_scf_tester(num_warps, block_size, shape)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user