TritonGPU combiner
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
#ifndef Triton_OPS
|
||||
#define Triton_OPS
|
||||
#ifndef TRITON_OPS
|
||||
#define TRITON_OPS
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
@@ -64,25 +64,6 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_CacheModifierAttr : I32EnumAttr<
|
||||
"CacheModifier", "",
|
||||
[
|
||||
I32EnumAttrCase<"NONE", 1, "none">,
|
||||
I32EnumAttrCase<"CA", 2, "ca">,
|
||||
I32EnumAttrCase<"CG", 3, "cg">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
"EvictionPolicy", "",
|
||||
[
|
||||
I32EnumAttrCase<"NORMAL", 1, "normal">,
|
||||
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
|
||||
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
MemoryEffects<[MemRead]>,
|
||||
@@ -221,45 +202,12 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
// let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
/*case*/
|
||||
[
|
||||
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
|
||||
I32EnumAttrCase<"MAX", 2, "max">,
|
||||
I32EnumAttrCase<"MIN", 3, "min">,
|
||||
I32EnumAttrCase<"FADD", 4, "fadd">,
|
||||
I32EnumAttrCase<"FMAX", 5, "fmax">,
|
||||
I32EnumAttrCase<"FMIN", 6, "fmin">,
|
||||
I32EnumAttrCase<"XOR", 7, "xor">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis);
|
||||
}
|
||||
|
||||
// atomic
|
||||
def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
"RMWOp", "",
|
||||
[
|
||||
I32EnumAttrCase<"AND", 1, "and">,
|
||||
I32EnumAttrCase<"OR", 2, "or">,
|
||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||
I32EnumAttrCase<"ADD", 4, "add">,
|
||||
I32EnumAttrCase<"FADD", 5, "fadd">,
|
||||
I32EnumAttrCase<"MAX", 6, "max">,
|
||||
I32EnumAttrCase<"MIN", 7, "min">,
|
||||
I32EnumAttrCase<"UMAX", 8, "umax">,
|
||||
I32EnumAttrCase<"UMIN", 9, "umin">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
#ifndef TRITONGPU_ATTRDEFS
|
||||
#define TRITONGPU_ATTRDEFS
|
||||
|
||||
include "TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
// include "mlir/IR/TensorEncoding.td"
|
||||
|
||||
class TritonGPU_Attr<string name, list<Trait> traits = []>
|
||||
|
@@ -5,6 +5,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
@@ -33,7 +34,18 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
let arguments = (ins I32Attr:$num);
|
||||
}
|
||||
|
||||
// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {}
|
||||
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
[MemoryEffects<[MemRead, MemWrite]>]> {
|
||||
let summary = "copy async";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
// let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||
|
@@ -6,6 +6,12 @@
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass();
|
||||
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
}
|
||||
}
|
||||
|
||||
// /// Generate the code for registering passes.
|
||||
// #define GEN_PASS_REGISTRATION
|
||||
// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
@@ -26,4 +26,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
let summary = "combine triton gpu ops";
|
||||
|
||||
let description = [{
|
||||
convert_layout(load(%ptr, %mask, %other), #SMEM_LAYOUT) =>
|
||||
copy_async(%ptr, %mask, %other), barrier
|
||||
|
||||
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
|
||||
convert_layout(%src, #LAYOUT_1)
|
||||
|
||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::gpu::createCombineOpsPass";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user