Add TritonCombineOps

This commit is contained in:
Yan Da
2022-04-27 19:28:21 +08:00
parent c70f6b666e
commit 8dfe78f6cf
20 changed files with 239 additions and 41 deletions

View File

@@ -1,10 +1,2 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(TritonTableGen)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,10 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(TritonTableGen)

View File

@@ -8,11 +8,11 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "triton/Dialect/Triton/Traits.h"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/Dialect.h.inc"
#include "triton/Dialect/Triton/OpsEnums.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/Ops.h.inc"
#include "triton/Dialect/Triton/IR/Ops.h.inc"
#endif // TRITON_IR_DIALECT_H_

View File

@@ -5,6 +5,6 @@
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/Types.h.inc"
#include "triton/Dialect/Triton/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
add_public_tablegen_target(TritonTransformsIncGen)

View File

@@ -0,0 +1,14 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
}
}
#endif

View File

@@ -0,0 +1,23 @@
#ifndef TRITON_PASSES
#define TRITON_PASSES
include "mlir/Pass/PassBase.td"
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
let summary = "combine ops";
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
}];
let constructor = "mlir::triton::createCombineOpsPass";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
/*SelectOp*/"mlir::StandardOpsDialect"];
}
#endif