TritonGPU combiner

This commit is contained in:
Yan Da
2022-05-16 19:16:01 +08:00
parent 0e68e6eb59
commit e3916c3a46
9 changed files with 109 additions and 118 deletions

View File

@@ -1,9 +1,9 @@
#ifndef Triton_OPS
#define Triton_OPS
#ifndef TRITON_OPS
#define TRITON_OPS
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/EnumAttr.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
@@ -64,25 +64,6 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
//
// Load/Store Ops
//
// Enum attribute `CacheModifier` built on I32EnumAttr, with an empty summary
// string. It declares three cases, numbered from 1 (no case uses the value 0):
//   NONE = 1 ("none"), CA = 2 ("ca"), CG = 3 ("cg").
// NOTE(review): the meaning of "ca"/"cg" is not stated here; presumably they
// correspond to load cache modes -- confirm against the lowering code.
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
]> {
// C++ namespace set for this enum.
let cppNamespace = "::mlir::triton";
}
// Enum attribute `EvictionPolicy` built on I32EnumAttr, with an empty summary
// string. It declares three cases, numbered from 1:
//   NORMAL = 1 ("normal"), EVICT_FIRST = 2 ("evict_first"),
//   EVICT_LAST = 3 ("evict_last").
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
// C++ namespace set for this enum.
let cppNamespace = "::mlir::triton";
}
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
MemoryEffects<[MemRead]>,
@@ -221,45 +202,12 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
// let hasCanonicalizer = 1;
}
// reduction
// Enum attribute `RedOp` built on I32EnumAttr; it is used as the type of the
// `$reduce_op` argument of TT_ReduceOp. Cases are numbered 1..7:
//   ADD, MAX, MIN, FADD, FMAX, FMIN, XOR.
// Each case is written as (symbol, integer value, string form).
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"MAX", 2, "max">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"FADD", 4, "fadd">,
I32EnumAttrCase<"FMAX", 5, "fmax">,
I32EnumAttrCase<"FMIN", 6, "fmin">,
I32EnumAttrCase<"XOR", 7, "xor">
]> {
// C++ namespace set for this enum.
let cppNamespace = "::mlir::triton";
}
// Op with mnemonic "reduce" and no traits listed.
// Arguments:
//   $reduce_op - TT_RedOpAttr
//   $operand   - TT_Type
//   $axis      - I32Attr
// NOTE(review): no `results` clause is visible in this view of the
// definition -- confirm whether one exists in the full file.
def TT_ReduceOp : TT_Op<"reduce"> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis);
}
// atomic
// Enum attribute `RMWOp` built on I32EnumAttr, with an empty summary string.
// Cases are numbered 1..9:
//   AND, OR, XOR, ADD, FADD, MAX, MIN, UMAX, UMIN.
// Note that the TableGen def name (TT_AtomicRMWAttr) differs from the enum
// name passed to I32EnumAttr ("RMWOp").
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
[
I32EnumAttrCase<"AND", 1, "and">,
I32EnumAttrCase<"OR", 2, "or">,
I32EnumAttrCase<"XOR", 3, "xor">,
I32EnumAttrCase<"ADD", 4, "add">,
I32EnumAttrCase<"FADD", 5, "fadd">,
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">
]> {
// C++ namespace set for this enum.
let cppNamespace = "::mlir::triton";
}
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
let summary = "atomic rmw";

View File

@@ -1,7 +1,7 @@
#ifndef TRITONGPU_ATTRDEFS
#define TRITONGPU_ATTRDEFS
include "TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []>

View File

@@ -5,6 +5,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -33,7 +34,18 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let arguments = (ins I32Attr:$num);
}
// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {}
// Op with mnemonic "copy_async", declared with both MemRead and MemWrite
// memory effects.
// Arguments:
//   $ptr        - TT_PtrTensor
//   $mask       - BoolLike
//   $other      - TT_Type
//   $cache      - TT_CacheModifierAttr
//   $evict      - TT_EvictionPolicyAttr
//   $isVolatile - BoolAttr
// Result:
//   $result     - TT_Type
// NOTE(review): the exact role of $mask/$other is not defined here; presumably
// they match the load op this is combined from -- confirm.
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
[MemoryEffects<[MemRead, MemWrite]>]> {
let summary = "copy async";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
// Custom assembly format is currently disabled (kept for reference).
// let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
}
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
def TTG_CmpIOp : TTG_Op<"cmpi"> {

View File

@@ -6,6 +6,12 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass();
// Declarations that live in the nested triton::gpu namespace.
namespace triton {
namespace gpu {
// Factory for the combine-ops pass; returns an owning pointer to a new Pass.
std::unique_ptr<Pass> createCombineOpsPass();
} // namespace gpu
} // namespace triton
// /// Generate the code for registering passes.
// #define GEN_PASS_REGISTRATION
// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

View File

@@ -26,4 +26,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
"mlir::arith::ArithmeticDialect"];
}
// Declares the pass with command-line argument "tritongpu-combine", which
// operates on mlir::ModuleOp. The `description` below lists the three rewrite
// patterns the pass is documented to perform. The pass object is created by
// the function named in `constructor`, and the dialects named in
// `dependentDialects` are declared as dependencies of the pass.
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
let description = [{
convert_layout(load(%ptr, %mask, %other), #SMEM_LAYOUT) =>
copy_async(%ptr, %mask, %other), barrier
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::triton::gpu::createCombineOpsPass";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
#endif