[Triton-IR] Fix LoadOp definition (#771) (#777)

This commit is contained in:
Shintaro Iwasaki
2022-10-13 18:53:00 -07:00
committed by GitHub
parent 963d031247
commit 5898352f97
4 changed files with 42 additions and 16 deletions

View File

@@ -50,21 +50,34 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
if (allOperands.size() >= 2)
int hasMask = 0, hasOther = 0;
if (allOperands.size() >= 2) {
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
if (allOperands.size() >= 3)
hasMask = 1;
}
if (allOperands.size() >= 3) {
operandTypes.push_back(resultTypes[0]); // other
hasOther = 1;
}
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
auto operand_segment_sizesAttrName =
LoadOp::operand_segment_sizesAttrName(result.name);
result.addAttribute(
operand_segment_sizesAttrName,
parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
return success();
}
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
printer << " ";
printer << loadOp.getOperation()->getOperands();
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
// "operand_segment_sizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(loadOp->getAttrs(),
{loadOp.operand_segment_sizesAttrName()});
printer << " : ";
printer.printStrippedAttrOrType(loadOp.result().getType());
}
@@ -148,6 +161,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
state.addOperands(other);
}
}
state.addAttribute(
operand_segment_sizesAttrName(state.name),
builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));