Add TritonCombineOps
This commit is contained in:
@@ -1,10 +1,2 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||
add_public_tablegen_target(TritonTableGen)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
10
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
10
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||
add_public_tablegen_target(TritonTableGen)
|
@@ -8,11 +8,11 @@
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
||||
#include "triton/Dialect/Triton/Traits.h"
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
#include "triton/Dialect/Triton/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/OpsEnums.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/Ops.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
@@ -5,6 +5,6 @@
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/Types.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
|
||||
add_public_tablegen_target(TritonTransformsIncGen)
|
14
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
14
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef TRITON_PASSES
|
||||
#define TRITON_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "combine ops";
|
||||
let description = [{
|
||||
dot(a, b, 0) + c => dot(a, b, c)
|
||||
|
||||
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
|
||||
|
||||
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||
load(ptrs, broadcast(cond), other)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::createCombineOpsPass";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||
}
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user