[Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment correctly considered. Some cases: ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { // mask = make_range(128) < n_element } ``` This should get the vec=2 `ld`/`st` instructions. While the following example ```mlir // num_warp = 2 // block_size = 128 func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) { // mask = make_range(128) < n_element } ``` it should get the vec=1 `ld`/`st` instructions.
This commit is contained in:
@@ -209,6 +209,33 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// CmpI
|
||||
if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
|
||||
llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
|
||||
op->getResult(0).getType().dyn_cast<TensorType>()) {
|
||||
auto resTy = op->getResult(0).getType().cast<TensorType>();
|
||||
short rank = resTy.getRank();
|
||||
auto lhsInfo = operands[0]->getValue();
|
||||
auto rhsInfo = operands[1]->getValue();
|
||||
auto shape = resTy.getShape();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
for (short d = 0; d < rank; ++d) {
|
||||
if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
|
||||
rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d))
|
||||
constancy.push_back(
|
||||
gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
|
||||
else
|
||||
constancy.push_back(1);
|
||||
|
||||
divisibility.push_back(shape[d]);
|
||||
contiguity.push_back(1);
|
||||
}
|
||||
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// UnrealizedConversionCast
|
||||
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
|
||||
// in the process of a PartialConversion, where UnrealizedConversionCast
|
||||
@@ -219,7 +246,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
|
||||
// join all lattice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(curr);
|
||||
|
Reference in New Issue
Block a user