2022-04-27 19:28:21 +08:00
|
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
|
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
|
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
2022-06-17 16:19:47 +08:00
|
|
|
using namespace mlir;
|
2022-04-27 19:28:21 +08:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
2022-07-28 03:50:08 +08:00
|
|
|
/// Returns true when `val` is a constant zero, or a `triton::BroadcastOp`
/// whose source operand is a constant zero.
///
/// "Constant zero" means either an integer zero (`m_Zero`) or a floating-point
/// zero of either sign (`m_AnyZeroFloat`). Only a single level of broadcast is
/// looked through.
bool isZero(mlir::Value val) {
  // Shared predicate: integer zero or +/-0.0 float constant.
  auto isConstantZero = [](mlir::Value candidate) {
    return mlir::matchPattern(candidate, mlir::m_Zero()) ||
           mlir::matchPattern(candidate, mlir::m_AnyZeroFloat());
  };

  if (isConstantZero(val))
    return true;

  // broadcast(constant_0): peel one broadcast and test its source.
  auto broadcast = val.getDefiningOp<mlir::triton::BroadcastOp>();
  return broadcast && isConstantZero(broadcast.src());
}
|
2022-04-27 19:28:21 +08:00
|
|
|
|
2022-07-28 03:50:08 +08:00
|
|
|
/// Returns true when the constant attribute `value` can be folded through a
/// broadcast: either a splat `DenseElementsAttr`, or a scalar `FloatAttr` /
/// `IntegerAttr`.
bool isBroadcastConstantCombinable(Attribute value) {
  auto dense = value.dyn_cast<DenseElementsAttr>();
  // Non-dense attributes qualify only if they are plain scalars.
  if (!dense)
    return value.isa<FloatAttr, IntegerAttr>();
  // Dense attributes qualify only when every element is the same value.
  return dense.isSplat();
}
|
2022-04-27 19:28:21 +08:00
|
|
|
|
2022-07-28 03:50:08 +08:00
|
|
|
/// Builds a `DenseElementsAttr` typed like `bcast_res` whose elements are all
/// the scalar carried by `value`.
///
/// If `value` is a `DenseElementsAttr`, its splat element is used. That
/// presumably relies on `isBroadcastConstantCombinable` having already
/// checked `isSplat()` (TODO confirm at the call site). Otherwise `value`
/// itself is used as the element.
///
/// `builder` is not used here. NOTE(review): it appears to be kept only to
/// match the calling convention of the generated rewrite code — confirm.
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
                                   Value bcast_res) {
  Attribute element = value;
  if (auto dense = value.dyn_cast<DenseElementsAttr>())
    element = dense.getSplatValue<Attribute>();

  Type resultType = bcast_res.getType();
  return DenseElementsAttr::get(resultType, element);
}
|
2022-06-17 16:19:47 +08:00
|
|
|
|
2022-07-28 03:50:08 +08:00
|
|
|
#include "TritonCombine.inc"
|
2022-07-26 17:25:03 -07:00
|
|
|
|
2022-04-27 19:28:21 +08:00
|
|
|
} // anonymous namespace
|
|
|
|
|
|
|
|
#define GEN_PASS_CLASSES
|
|
|
|
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
|
|
|
|
|
|
|
/// Module-level pass that greedily applies Triton's combine (peephole)
/// patterns generated into `TritonCombine.inc`. It signals pass failure if
/// the greedy driver reports failure.
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::ModuleOp module = getOperation();

    mlir::RewritePatternSet patterns(ctx);
    // Dot + Add fusion: integer / float variants, each in both operand orders.
    patterns.add<CombineDotAddIPattern, CombineDotAddFPattern,
                 CombineDotAddIRevPattern, CombineDotAddFRevPattern>(ctx);
    // Remaining combines: select-of-masked-load, GEP chains, and
    // broadcast-of-constant.
    patterns.add<CombineSelectMaskedLoadPattern, CombineGEPPattern,
                 CombineBroadcastConstantPattern>(ctx);

    if (mlir::failed(
            mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))))
      signalPassFailure();
  }
};
|
|
|
|
|
|
|
|
std::unique_ptr<mlir::Pass> mlir::triton::createCombineOpsPass() {
|
|
|
|
return std::make_unique<CombineOpsPass>();
|
|
|
|
}
|