Add Triton CombineOps

This commit is contained in:
Yan Da
2022-04-27 13:45:56 +08:00
parent 9e304cf79d
commit 74585fb970
9 changed files with 339 additions and 134 deletions

View File

@@ -131,6 +131,7 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
@@ -161,6 +162,7 @@ def TT_LoadOp : TT_Op<"load",
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
@@ -214,6 +216,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
@@ -259,6 +263,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
// let hasCanonicalizer = 1;
}
// reduction

View File

@@ -0,0 +1,12 @@
#ifndef TRITON_TRANSFORMS_PASSES_H_
#define TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
std::unique_ptr<Pass> createCombineOpsPass();
// // Registration
// #define GEN_PASS_REGISTRATION
// #include
#endif // TRITON_TRANSFORMS_PASSES_H_