[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)

Correct the load/store ops' vector size by properly taking the mask's alignment into account.

Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
    // mask = make_range(128) < n_element
}
```
Since `%n_elements` carries `tt.divisibility = 16`, the mask `make_range(128) < n_elements` can only change value at a 16-element boundary; with 128 elements spread over 2 * 32 = 64 threads (2 elements per thread), this should get vec=2 `ld`/`st` instructions.

The following example, by contrast,

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
    // mask = make_range(128) < n_element
}
```
should only get vec=1 `ld`/`st` instructions, because `%n_elements` has no divisibility attribute, so the mask boundary may fall at any element.
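
The selection rule can be summarized as follows. Below is a simplified sketch, not the actual backend code; parameter names such as `maskAlignment` and `elemBits` are illustrative. The vector width is bounded by the pointer alignment, the mask alignment, the 128-bit access limit, and the number of elements each thread actually owns.

```cpp
#include <algorithm>

// Simplified sketch: pick the vector width for a masked load/store.
// All parameter names are illustrative, not the backend's actual API.
unsigned pickVectorWidth(unsigned ptrAlignment,   // pointer alignment, in elements
                         unsigned maskAlignment,  // consecutive lanes sharing the same mask value
                         unsigned elemBits,       // element size in bits (f32 -> 32)
                         unsigned numElems,       // elements in the block, e.g. 128
                         unsigned numWarps) {     // e.g. 2 -> 64 threads
  unsigned numThreads = numWarps * 32;
  unsigned elemsPerThread = std::max(numElems / numThreads, 1u);
  unsigned vec = std::min(ptrAlignment, 128 / elemBits); // at most one 128-bit access
  vec = std::min(vec, maskAlignment);                    // never straddle a mask boundary
  vec = std::min(vec, elemsPerThread);                   // a thread cannot load more than it owns
  return vec;
}

// Example 1: n_elements divisible by 16 -> maskAlignment = 16 -> vec = min(4, 16, 2) = 2
// Example 2: no divisibility hint       -> maskAlignment = 1  -> vec = min(4, 1, 2)  = 1
```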
Author: Yan Chunwei, 2022-10-18 11:43:50 +08:00 (committed by GitHub)
Commit: 4464646efb (parent 38a80664b5)
6 changed files with 383 additions and 227 deletions

Diff excerpt from the `CoalescePass`:

```
@@ -1,4 +1,5 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>
@@ -23,6 +24,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
  return contiguity[x] > contiguity[y];
});
int numElems = product(origType.getShape());
int numThreads = numWarps * 32;
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
@@ -31,7 +37,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned perThread = std::min(alignment, 128 / numBits);
sizePerThread[order[0]] = perThread;
sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
SmallVector<unsigned> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
// create encoding
```
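
To see the effect of the new clamp on the examples above, here is the arithmetic as a standalone sketch (assuming f32 elements, so `numBits = 32`, and that the alignment-based bound works out to 128 / 32 = 4 elements):

```cpp
#include <algorithm>
#include <cassert>

int main() {
  // num_warps = 2, block_size = 128, f32 elements
  int numElems = 128;
  int numThreads = 2 * 32;                                    // 64 threads
  int numElemsPerThread = std::max(numElems / numThreads, 1); // 2
  int perThread = std::min(/*alignment, assumed*/ 16, 128 / 32); // 4
  // Before the fix, sizePerThread would be 4 even though each thread
  // only owns 2 elements; now it is clamped to 2.
  int sizePerThread = std::min(perThread, numElemsPerThread);
  assert(sizePerThread == 2);
  return 0;
}
```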