[TritonIR] simplify Load/StoreOps when mask is true/false (#79)
* [TritonIR] fix Load/Store/CopyAsyncOp's parsers * [TritonIR] simplify Load/StoreOps when mask is true/false * [TEST] adds tests to check load/store simplification
This commit is contained in:
@@ -92,6 +92,91 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// load(ptr, splat(1), ...) -> load(ptr, ...)
|
||||
// load(ptr, splat(0), other, ...) -> other
|
||||
struct CanonicalizeMaskedLoadPattern
|
||||
: public mlir::OpRewritePattern<triton::LoadOp> {
|
||||
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::LoadOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::LoadOp loadOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = loadOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
// If there's no "other", the value is "undef". Perhaps we want to
|
||||
// optimize it in the future.x
|
||||
auto otherVal = loadOp.other();
|
||||
if (!otherVal)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOp(loadOp, otherVal);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedLoadPattern>(context);
|
||||
}
|
||||
|
||||
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
|
||||
// store(ptr, value, splat(0), ...) -> [none]
|
||||
struct CanonicalizeMaskedStorePattern
|
||||
: public mlir::OpRewritePattern<triton::StoreOp> {
|
||||
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::StoreOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::StoreOp storeOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = storeOp.mask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
|
||||
storeOp.value());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
rewriter.eraseOp(storeOp);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedStorePattern>(context);
|
||||
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
|
Reference in New Issue
Block a user