Add Triton CombineOps

This commit is contained in:
Yan Da
2022-04-27 13:45:56 +08:00
parent 9e304cf79d
commit 74585fb970
9 changed files with 339 additions and 134 deletions

View File

@@ -1,3 +1,4 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(ir)
# add_subdirectory(transforms)

View File

@@ -55,16 +55,12 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
dce.run(ir);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
@@ -72,14 +68,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
if (target->is_gpu())
cts.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);

View File

@@ -95,5 +95,18 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
state.addTypes({resultType});
}
//-- DotOp --
//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,43 @@
#include "triton/transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
// <patterns>
struct CombineDotOp : public RewritePattern {
explicit CombineDotOp(MLIRContext *context)
: RewritePattern(/*rootName*/FuncOp::getOperationName(), /*Benefit*/1, context);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
//
}
};
// </patterns>
/// Passes.td (?)
struct CombineOpsPass { // : public mlir::OperationPass<FuncOp>
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
RewritePatternSet patterns(context);
patterns.add<CombineDotOp>();
patterns.add<CombineSelectMaskedLoadOp>();
patterns.add<CombineGEPOp>();
patterns.add<CombineReduceOp>();
// TODO: populate xxx Patter(?)
// TODO: should be use applyPartialConversion(...) ?
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
};
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::triton::createCombineOpsPass() {
return std::make_unique<CombineOpsPass>();
}