Add Triton CombineOps
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
# add_subdirectory(codegen)
|
||||
add_subdirectory(driver)
|
||||
add_subdirectory(ir)
|
||||
# add_subdirectory(transforms)
|
||||
|
@@ -55,16 +55,12 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
dce.run(ir);
|
||||
disassociate.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
coalesce.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
@@ -72,14 +68,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
swizzle.run(ir);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
|
@@ -95,5 +95,18 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
//-- DotOp --
|
||||
|
||||
//-- BroadcastOp --
|
||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
|
||||
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
43
lib/transforms/CombineOps.cpp
Normal file
43
lib/transforms/CombineOps.cpp
Normal file
@@ -0,0 +1,43 @@
|
||||
#include "triton/transforms/Passes.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// <patterns>
|
||||
struct CombineDotOp : public RewritePattern {
|
||||
explicit CombineDotOp(MLIRContext *context)
|
||||
: RewritePattern(/*rootName*/FuncOp::getOperationName(), /*Benefit*/1, context);
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
//
|
||||
|
||||
}
|
||||
};
|
||||
// </patterns>
|
||||
|
||||
/// Passes.td (?)
|
||||
struct CombineOpsPass { // : public mlir::OperationPass<FuncOp>
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<CombineDotOp>();
|
||||
patterns.add<CombineSelectMaskedLoadOp>();
|
||||
patterns.add<CombineGEPOp>();
|
||||
patterns.add<CombineReduceOp>();
|
||||
|
||||
// TODO: populate xxx Patter(?)
|
||||
|
||||
// TODO: should be use applyPartialConversion(...) ?
|
||||
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
};
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::triton::createCombineOpsPass() {
|
||||
return std::make_unique<CombineOpsPass>();
|
||||
}
|
Reference in New Issue
Block a user