[TritonIR] simplify Load/StoreOps when mask is true/false (#79)

* [TritonIR] fix Load/Store/CopyAsyncOp's parsers

* [TritonIR] simplify Load/StoreOps when mask is true/false

* [TEST] adds tests to check load/store simplification
This commit is contained in:
Shintaro Iwasaki
2022-08-24 12:55:49 -07:00
committed by GitHub
parent 1b513c9866
commit 84aa7d025a
6 changed files with 269 additions and 36 deletions

View File

@@ -36,6 +36,68 @@ static Type getPointerTypeFromTensor(Type type) {
return Type();
}
// Parser & printer for assembly forms
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type resultTypes[1];
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(resultTypes[0]))
return failure();
result.addTypes(resultTypes);
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
if (allOperands.size() >= 2)
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
if (allOperands.size() >= 3)
operandTypes.push_back(resultTypes[0]); // other
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
printer << " ";
printer << loadOp.getOperation()->getOperands();
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(loadOp.result().getType());
}
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type valueType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(valueType))
return failure();
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
operandTypes.push_back(valueType); // value
if (allOperands.size() >= 3)
operandTypes.push_back(getI1SameShape(valueType)); // mask
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
printer << " ";
printer << storeOp.getOperation()->getOperands();
printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(storeOp.value().getType());
}
} // namespace triton
} // namespace mlir