[TritonIR] simplify Load/StoreOps when mask is true/false (#79)

* [TritonIR] fix Load/Store/CopyAsyncOp's parsers

* [TritonIR] simplify Load/StoreOps when mask is true/false

* [TEST] adds tests to check load/store simplification
This commit is contained in:
Shintaro Iwasaki
2022-08-24 12:55:49 -07:00
committed by GitHub
parent 1b513c9866
commit 84aa7d025a
6 changed files with 269 additions and 36 deletions

View File

@@ -36,6 +36,68 @@ static Type getPointerTypeFromTensor(Type type) {
return Type();
}
// Parser & printer for assembly forms
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type resultTypes[1];
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(resultTypes[0]))
return failure();
result.addTypes(resultTypes);
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
if (allOperands.size() >= 2)
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
if (allOperands.size() >= 3)
operandTypes.push_back(resultTypes[0]); // other
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
printer << " ";
printer << loadOp.getOperation()->getOperands();
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(loadOp.result().getType());
}
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type valueType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(valueType))
return failure();
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
operandTypes.push_back(valueType); // value
if (allOperands.size() >= 3)
operandTypes.push_back(getI1SameShape(valueType)); // mask
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
printer << " ";
printer << storeOp.getOperation()->getOperands();
printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(storeOp.value().getType());
}
} // namespace triton
} // namespace mlir

View File

@@ -92,6 +92,91 @@ public:
}
};
// load(ptr, splat(1), ...) -> load(ptr, ...)
// load(ptr, splat(0), other, ...) -> other
struct CanonicalizeMaskedLoadPattern
: public mlir::OpRewritePattern<triton::LoadOp> {
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::LoadOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::LoadOp loadOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = loadOp.mask();
if (!mask)
return mlir::failure();
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::LoadOp>(
loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile());
} else {
// mask = splat(0)
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.x
auto otherVal = loadOp.other();
if (!otherVal)
return mlir::failure();
rewriter.replaceOp(loadOp, otherVal);
}
return mlir::success();
}
};
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedLoadPattern>(context);
}
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
// store(ptr, value, splat(0), ...) -> [none]
struct CanonicalizeMaskedStorePattern
: public mlir::OpRewritePattern<triton::StoreOp> {
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::StoreOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::StoreOp storeOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = storeOp.mask();
if (!mask)
return mlir::failure();
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
storeOp.value());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);
}
return mlir::success();
}
};
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedStorePattern>(context);
}
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"