[TritonIR] simplify Load/StoreOps when mask is true/false (#79)
* [TritonIR] fix Load/Store/CopyAsyncOp's parsers
* [TritonIR] simplify Load/StoreOps when mask is true/false
* [TEST] adds tests to check load/store simplification
This commit is contained in:
@@ -36,6 +36,68 @@ static Type getPointerTypeFromTensor(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
// Parses the custom assembly form of a load:
//   <operands> [attr-dict] `:` <result-type>
// Only the result type is written in the text. The operand types are derived
// from it: the first operand (ptr) uses getPointerTypeFromTensor, an optional
// second operand (mask) uses getI1SameShape, and an optional third operand
// (other) uses the result type itself.
// Returns failure() if any parsing step or the operand resolution fails.
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type resultTypes[1];
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parser.parseCustomTypeWithFallback(resultTypes[0]))
    return failure();

  result.addTypes(resultTypes);

  // Build the expected operand types in positional order.
  Type loadedType = resultTypes[0];
  SmallVector<Type> expectedTypes;
  expectedTypes.push_back(getPointerTypeFromTensor(loadedType)); // ptr
  if (operands.size() >= 2)
    expectedTypes.push_back(getI1SameShape(loadedType)); // mask
  if (operands.size() >= 3)
    expectedTypes.push_back(loadedType); // other

  // resolveOperands reports a count/type mismatch at operandsLoc.
  return parser.resolveOperands(operands, expectedTypes, operandsLoc,
                                result.operands);
}
|
||||
|
||||
// Prints a load in the form accepted by parseLoadOp:
//   <operands> [attr-dict] `:` <result-type>
// No attributes are elided from the attribute dictionary.
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
  printer << " " << loadOp.getOperation()->getOperands();
  printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
  printer << " : ";
  printer.printStrippedAttrOrType(loadOp.result().getType());
}
|
||||
|
||||
// Parses the custom assembly form of a store:
//   <operands> [attr-dict] `:` <value-type>
// Only the stored value's type is written in the text. The operand types are
// derived from it: the first operand (ptr) uses getPointerTypeFromTensor, the
// second operand (value) uses the value type itself, and an optional third
// operand (mask) uses getI1SameShape.
// Returns failure() if any parsing step or the operand resolution fails.
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type valueType;
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parser.parseCustomTypeWithFallback(valueType))
    return failure();

  // Build the expected operand types in positional order.
  SmallVector<Type> expectedTypes;
  expectedTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
  expectedTypes.push_back(valueType);                           // value
  if (operands.size() >= 3)
    expectedTypes.push_back(getI1SameShape(valueType)); // mask

  // resolveOperands reports a count/type mismatch at operandsLoc.
  return parser.resolveOperands(operands, expectedTypes, operandsLoc,
                                result.operands);
}
|
||||
|
||||
// Prints a store in the form accepted by parseStoreOp:
//   <operands> [attr-dict] `:` <value-type>
// No attributes are elided from the attribute dictionary.
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
  printer << " " << storeOp.getOperation()->getOperands();
  printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
  printer << " : ";
  printer.printStrippedAttrOrType(storeOp.value().getType());
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
@@ -92,6 +92,91 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// load(ptr, splat(1), ...) -> load(ptr, ...)
|
||||
// load(ptr, splat(0), other, ...) -> other
|
||||
struct CanonicalizeMaskedLoadPattern
|
||||
: public mlir::OpRewritePattern<triton::LoadOp> {
|
||||
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::LoadOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::LoadOp loadOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = loadOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
// If there's no "other", the value is "undef". Perhaps we want to
|
||||
// optimize it in the future.x
|
||||
auto otherVal = loadOp.other();
|
||||
if (!otherVal)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOp(loadOp, otherVal);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Adds the canonicalization patterns for triton::LoadOp to `results`;
// currently this is only the constant-mask simplification.
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<CanonicalizeMaskedLoadPattern>(context);
}
|
||||
|
||||
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
|
||||
// store(ptr, value, splat(0), ...) -> [none]
|
||||
struct CanonicalizeMaskedStorePattern
|
||||
: public mlir::OpRewritePattern<triton::StoreOp> {
|
||||
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::StoreOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::StoreOp storeOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = storeOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
|
||||
storeOp.value());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
rewriter.eraseOp(storeOp);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Adds the canonicalization patterns for triton::StoreOp to `results`;
// currently this is only the constant-mask simplification.
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CanonicalizeMaskedStorePattern>(context);
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
|
Reference in New Issue
Block a user