Add Triton enum attribute definitions and the TritonGPU combine-ops pass.

  * TritonAttrDefs.td: I32 enum attributes CacheModifier, EvictionPolicy,
    RedOp and RMWOp, all emitted into the ::mlir::triton C++ namespace.
  * Combine.cpp: TritonGPUCombineOpsPass, which greedily applies the
    patterns generated from Combine.td to the module.
  * Combine.td: three DRR patterns rooted at TTG_ConvertLayoutOp.

NOTE(review): the angle-bracket arguments in this patch had been lost and
were filled in from the surrounding context. Please confirm against the tree:
  * the shared-encoding attribute class name used in isSharedLayout();
  * the exact text of the two CPred predicates in Combine.td;
  * the order of the three patterns.add<...>() calls.
The "index" blob ids below were not recomputed after these changes.

diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
new file mode 100644
index 000000000..ac50d2d1a
--- /dev/null
+++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -0,0 +1,59 @@
+#ifndef TRITON_ATTR_DEFS
+#define TRITON_ATTR_DEFS
+
+include "mlir/IR/EnumAttr.td"
+
+// Attrs for LoadOp
+def TT_CacheModifierAttr : I32EnumAttr<
+    "CacheModifier", "",
+    [
+        I32EnumAttrCase<"NONE", 1, "none">,
+        I32EnumAttrCase<"CA", 2, "ca">,
+        I32EnumAttrCase<"CG", 3, "cg">,
+    ]> {
+    let cppNamespace = "::mlir::triton";
+}
+def TT_EvictionPolicyAttr : I32EnumAttr<
+    "EvictionPolicy", "",
+    [
+        I32EnumAttrCase<"NORMAL", 1, "normal">,
+        I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
+        I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
+    ]> {
+    let cppNamespace = "::mlir::triton";
+}
+
+// reduction
+def TT_RedOpAttr : I32EnumAttr<
+    /*name*/"RedOp", /*summary*/"",
+    /*case*/
+    [
+        I32EnumAttrCase<"ADD", 1, "add">,
+        I32EnumAttrCase<"MAX", 2, "max">,
+        I32EnumAttrCase<"MIN", 3, "min">,
+        I32EnumAttrCase<"FADD", 4, "fadd">,
+        I32EnumAttrCase<"FMAX", 5, "fmax">,
+        I32EnumAttrCase<"FMIN", 6, "fmin">,
+        I32EnumAttrCase<"XOR", 7, "xor">
+    ]> {
+    let cppNamespace = "::mlir::triton";
+}
+
+// atomic
+def TT_AtomicRMWAttr : I32EnumAttr<
+    "RMWOp", "",
+    [
+        I32EnumAttrCase<"AND", 1, "and">,
+        I32EnumAttrCase<"OR", 2, "or">,
+        I32EnumAttrCase<"XOR", 3, "xor">,
+        I32EnumAttrCase<"ADD", 4, "add">,
+        I32EnumAttrCase<"FADD", 5, "fadd">,
+        I32EnumAttrCase<"MAX", 6, "max">,
+        I32EnumAttrCase<"MIN", 7, "min">,
+        I32EnumAttrCase<"UMAX", 8, "umax">,
+        I32EnumAttrCase<"UMIN", 9, "umin">
+    ]> {
+    let cppNamespace = "::mlir::triton";
+}
+
+#endif
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
new file mode 100644
index 000000000..ada418f5c
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -0,0 +1,58 @@
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+
+#include <memory>
+
+using namespace mlir;
+
+// Returns true when `v` is a ranked tensor whose encoding attribute is the
+// TritonGPU shared encoding; any other value type or encoding yields false.
+// A tensor without an encoding (null attribute) is treated as not shared.
+static bool isSharedLayout(Value v) {
+  if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
+    Attribute encoding = tensorType.getEncoding();
+    return encoding && encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
+  }
+  return false;
+}
+
+// Rewrite patterns generated from Combine.td. Included after the helper
+// above so that the generated constraint code can call it.
+namespace {
+#include "TritonGPUCombine.inc"
+}
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+// Module pass that greedily applies the combine patterns and signals pass
+// failure if the rewrite driver reports failure.
+class TritonGPUCombineOpsPass
+    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+
+    patterns.add<ConvertLayoutOptPattern>(context);
+    patterns.add<CopyAsyncOptPattern>(context);
+    patterns.add<RedundantConvertLayoutOptPattern>(context);
+
+    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+      signalPassFailure();
+  }
+};
+
+// Factory for the combine pass.
+std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
+  return std::make_unique<TritonGPUCombineOpsPass>();
+}
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td
new file mode 100644
index 000000000..4ecd9f73a
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.td
@@ -0,0 +1,25 @@
+#ifndef TRITONGPU_PATTERNS
+#define TRITONGPU_PATTERNS
+
+include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
+include "triton/Dialect/Triton/IR/TritonOps.td"
+
+// convert_layout(load(...), #L) => copy_async(...); barrier
+// if #L is smem_layout
+def CopyAsyncOptPattern : Pat<
+  (TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile)),
+  (TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile),
+  [(Constraint<CPred<"isSharedLayout($0)">> $res)]>;
+
+// ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1)
+def ConvertLayoutOptPattern : Pat<
+  (TTG_ConvertLayoutOp (TTG_ConvertLayoutOp $x)),
+  (TTG_ConvertLayoutOp $x)>;
+
+// TODO: can we replace this with ConvertLayoutOp's folder?
+// ConvertLayout(x, #L) => x if x.layout() == #L
+def RedundantConvertLayoutOptPattern : Pat<
+  (TTG_ConvertLayoutOp:$res $x), (replaceWithValue $x),
+  [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+
+#endif