Add Triton CombineOps
This commit is contained in:
@@ -131,6 +131,7 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
@@ -161,6 +162,7 @@ def TT_LoadOp : TT_Op<"load",
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
@@ -214,6 +216,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
||||
@@ -259,6 +263,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
|
||||
// let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// reduction
|
||||
|
12
include/triton/transforms/Passes.h
Normal file
12
include/triton/transforms/Passes.h
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef TRITON_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
// // Registration
|
||||
// #define GEN_PASS_REGISTRATION
|
||||
// #include
|
||||
|
||||
#endif // TRITON_TRANSFORMS_PASSES_H_
|
Reference in New Issue
Block a user