[CI] run clang-format (#24)
This commit is contained in:
@@ -17,21 +17,23 @@ namespace {
|
||||
class CombineDotOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineDotOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
|
||||
if (isCandidate(op->getOperand(0)).succeeded()) {
|
||||
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(1), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(1), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
} else if (isCandidate(op->getOperand(1)).succeeded()) {
|
||||
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||
dotOp.b(), op->getOperand(0), dotOp.allowTF32());
|
||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
||||
op->getOperand(0), dotOp.allowTF32());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -54,7 +56,7 @@ private:
|
||||
return true;
|
||||
// broadcast(constant_0)
|
||||
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
|
||||
return true;
|
||||
}
|
||||
@@ -68,18 +70,19 @@ private:
|
||||
class CombineGEPOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineGEPOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::triton::GEPOp>(op)) {
|
||||
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
|
||||
auto loc = op->getLoc();
|
||||
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
loc, op->getOperand(1), gep2->getOperand(1));
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
|
||||
);
|
||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -92,20 +95,21 @@ public:
|
||||
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (llvm::isa<mlir::SelectOp>(op)) {
|
||||
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
|
||||
mlir::Value cond = op->getOperand(0);
|
||||
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
|
||||
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
|
||||
op, op->getResultTypes().front(),
|
||||
load.ptr(), load.mask(), op->getOperand(2),
|
||||
load.cache(), load.evict(), load.isVolatile()
|
||||
);
|
||||
op, op->getResultTypes().front(), load.ptr(), load.mask(),
|
||||
op->getOperand(2), load.cache(), load.evict(),
|
||||
load.isVolatile());
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
@@ -120,11 +124,11 @@ public:
|
||||
class CombineBroadcastConstantOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineBroadcastConstantOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
|
||||
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
|
||||
Attribute value = cst.getValue();
|
||||
@@ -132,15 +136,14 @@ public:
|
||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||
if (!denseValue.isSplat())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
|
||||
value = DenseElementsAttr::get(resType,
|
||||
denseValue.getSplatValue<Attribute>());
|
||||
} else {
|
||||
if (!value.isa<FloatAttr, IntegerAttr>())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, value);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, value, resType
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user