@@ -50,21 +50,34 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
|
||||
if (allOperands.size() >= 2)
|
||||
int hasMask = 0, hasOther = 0;
|
||||
if (allOperands.size() >= 2) {
|
||||
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
|
||||
if (allOperands.size() >= 3)
|
||||
hasMask = 1;
|
||||
}
|
||||
if (allOperands.size() >= 3) {
|
||||
operandTypes.push_back(resultTypes[0]); // other
|
||||
hasOther = 1;
|
||||
}
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
auto operand_segment_sizesAttrName =
|
||||
LoadOp::operand_segment_sizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
operand_segment_sizesAttrName,
|
||||
parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
|
||||
return success();
|
||||
}
|
||||
|
||||
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
|
||||
printer << " ";
|
||||
printer << loadOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(loadOp->getAttrs(),
|
||||
{loadOp.operand_segment_sizesAttrName()});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(loadOp.result().getType());
|
||||
}
|
||||
@@ -148,6 +161,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
state.addOperands(other);
|
||||
}
|
||||
}
|
||||
state.addAttribute(
|
||||
operand_segment_sizesAttrName(state.name),
|
||||
builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
|
||||
state.addAttribute(
|
||||
cacheAttrName(state.name),
|
||||
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
|
||||
|
Reference in New Issue
Block a user