[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)

Correct the Load/Store Op's vector size with the mask's alignment
correctly considered.

Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
    // mask = make_range(128) < n_element
}
```
This should get the vec=2 `ld`/`st` instructions.

While the following example

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
    // mask = make_range(128) < n_element
}
```
it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
Yan Chunwei
2022-10-18 11:43:50 +08:00
committed by GitHub
parent 38a80664b5
commit 4464646efb
6 changed files with 383 additions and 227 deletions

View File

@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// CmpI
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
op->getResult(0).getType().dyn_cast<TensorType>()) {
auto resTy = op->getResult(0).getType().cast<TensorType>();
short rank = resTy.getRank();
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto shape = resTy.getShape();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
for (short d = 0; d < rank; ++d) {
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
constancy.push_back(
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
else
constancy.push_back(1);
divisibility.push_back(shape[d]);
contiguity.push_back(1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// UnrealizedConversionCast
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(op->getResults());
}
// join all latice elements
// join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
result |= getLatticeElement(value).join(curr);