From 623c99609f2ec3a64abbe5d01756c4b4f5682276 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 11 Oct 2022 18:16:41 -0700 Subject: [PATCH] [Triton-IR] Added type inference and verifier for Triton-IR operations (#767) --- .../triton/Dialect/Triton/IR/CMakeLists.txt | 13 +- include/triton/Dialect/Triton/IR/Dialect.h | 20 ++ include/triton/Dialect/Triton/IR/Interfaces.h | 9 + include/triton/Dialect/Triton/IR/Traits.h | 71 ++--- .../triton/Dialect/Triton/IR/TritonDialect.td | 3 + .../Dialect/Triton/IR/TritonInterfaces.td | 6 + include/triton/Dialect/Triton/IR/TritonOps.td | 300 ++++++++++-------- include/triton/Dialect/TritonGPU/IR/Dialect.h | 1 + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 14 +- .../Dialect/TritonGPU/Transforms/Passes.td | 11 - .../TritonToTritonGPU/TritonToTritonGPU.cpp | 14 +- lib/Dialect/Triton/IR/CMakeLists.txt | 2 + lib/Dialect/Triton/IR/Dialect.cpp | 1 + lib/Dialect/Triton/IR/Interfaces.cpp | 0 lib/Dialect/Triton/IR/Ops.cpp | 72 ++++- lib/Dialect/Triton/IR/Traits.cpp | 69 ++++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 58 ++-- .../TritonGPU/Transforms/CMakeLists.txt | 1 - lib/Dialect/TritonGPU/Transforms/Verifier.cpp | 106 ------- python/src/triton.cc | 4 - python/triton/compiler.py | 1 - test/Analysis/test-allocation.mlir | 3 +- test/Analysis/test-membar.mlir | 4 +- test/Conversion/tritongpu_to_llvm.mlir | 30 +- test/TritonGPU/coalesce.mlir | 11 +- test/TritonGPU/combine.mlir | 16 +- test/TritonGPU/loop-pipeline.mlir | 2 +- 27 files changed, 494 insertions(+), 348 deletions(-) create mode 100644 include/triton/Dialect/Triton/IR/Interfaces.h create mode 100644 include/triton/Dialect/Triton/IR/TritonInterfaces.td create mode 100644 lib/Dialect/Triton/IR/Interfaces.cpp create mode 100644 lib/Dialect/Triton/IR/Traits.cpp delete mode 100644 lib/Dialect/TritonGPU/Transforms/Verifier.cpp diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index 46573add6..81af1dff1 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -1,10 +1,19 @@ set(LLVM_TARGET_DEFINITIONS TritonOps.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) mlir_tablegen(Types.h.inc -gen-typedef-decls) mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 6ce3bd76a..c211d287a 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -17,4 +17,24 @@ #define GET_OP_CLASSES #include "triton/Dialect/Triton/IR/Ops.h.inc" +namespace mlir { +namespace triton { + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, int axis, + Attribute 
&resultEncoding) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, int axis, + Attribute &resultEncoding) const = 0; +}; + +} // namespace triton +} // namespace mlir + #endif // TRITON_IR_DIALECT_H_ diff --git a/include/triton/Dialect/Triton/IR/Interfaces.h b/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 000000000..f8f3a6f74 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index fd20236f1..e83a8e3b9 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -10,50 +10,47 @@ namespace mlir { namespace OpTrait { -// TODO: should have `namespace triton {}` here + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +LogicalResult verifySameOperandsAndResultEncoding(Operation *op); +LogicalResult verifySameOperandsEncoding(Operation *op); +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple +int constexpr maxTensorNumElements = 1048576; +LogicalResult verifyTensorSize(Operation *op); +} // namespace impl template class TensorSizeTrait : public TraitBase { public: - // TODO: move impl to .cc files static LogicalResult verifyTrait(Operation *op) { - // The rationale for this number is to prevent users from creating programs - // that would have catastrophic register pressure and cause the compiler to - // hang. - // Since H100 has 256KB registers, we should allow users to create tensors - // of size up to 256K elements. 
It will spill for datatypes wider than 1B, - // but we probably should limit number of elements (rather than bytes) to - // keep specs simple - int constexpr maxElement = 1048576; - for (auto opType : op->getOperandTypes()) { - if (auto tensorType = opType.dyn_cast()) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxElement) - return op->emitError("Maximum allowed number of elements is ") - << maxElement << ", but " << *op << " has more than that"; - if ((numElements & (numElements - 1)) != 0) - return op->emitError("Number of elements must be power-of-two, but ") - << *op << " doesn't follow the rule"; - } - } + return impl::verifyTensorSize(op); + } +}; - for (auto opType : op->getResultTypes()) { - if (auto tensorType = opType.dyn_cast()) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxElement) - return op->emitError("Maximum allowed number of elements is ") - << maxElement << ", but " << *op << " has more than that"; - if ((numElements & (numElements - 1)) != 0) - return op->emitError("Number of elements must be power-of-two, but ") - << *op << " doesn't follow the rule"; - } - } +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; - return success(); +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); } }; diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td index eb5d5c8f2..ea82bedd8 100644 --- a/include/triton/Dialect/Triton/IR/TritonDialect.td +++ b/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -41,4 +41,7 @@ def Triton_Dialect : Dialect { let hasConstantMaterializer = 1; } +include "triton/Dialect/Triton/IR/TritonTypes.td" + + #endif // TRITON_DIALECT diff --git a/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 000000000..5c05a9d1f --- /dev/null +++ b/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,6 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" + +#endif // TRITON_INTERFACES \ No newline at end of file diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ce74c8f5e..c65d92be6 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -4,18 +4,23 @@ include "triton/Dialect/Triton/IR/TritonDialect.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType - +include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; + // // Op Base // class TT_Op traits = []> : - Op; + Op { +} // // CastOps @@ -25,7 +30,9 @@ class TT_Op 
traits = []> : // fptoui, fptosi, uitofp, sitofp, // extf, tructf, // extui, extsi, tructi -def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect, +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + NoSideEffect, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast int64 to pointer"; @@ -36,7 +43,9 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)"; } -def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect, +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + NoSideEffect, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast pointer to int64"; @@ -47,7 +56,9 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffec let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)"; } -def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect, +def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + NoSideEffect, /*DeclareOpInterfaceMethods*/]> { let summary = "Floating point casting for custom types"; @@ -67,11 +78,33 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect, // TODO: We need a verifier here. } +// +// Pointer Arith Ops +// + +def TT_AddPtrOp : TT_Op<"addptr", + [NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">, + TypesMatchWith<"result shape matches offset shape", + "result", "offset", + "getI32SameShape($_self)">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)"; +} + + // // Load/Store Ops // def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, SameVariadicOperandSize, MemoryEffects<[MemRead]>, TypesMatchWith<"infer ptr type from result type", @@ -110,6 +143,7 @@ def TT_LoadOp : TT_Op<"load", def TT_StoreOp : TT_Op<"store", [SameOperandsShape, + SameOperandsEncoding, MemoryEffects<[MemWrite]>, TypesMatchWith<"infer ptr type from value type", "value", "ptr", @@ -133,130 +167,11 @@ def TT_StoreOp : TT_Op<"store", let hasCanonicalizer = 1; } -def TT_AddPtrOp : TT_Op<"addptr", - [NoSideEffect, SameOperandsAndResultShape, - TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">, - TypesMatchWith<"result shape matches offset shape", - "result", "offset", - "getI32SameShape($_self)">]> { - let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset); - - let results = (outs TT_PtrLike:$result); - - let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)"; -} - - // -// Shape Manipulation Ops +// Atomic Op // -def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "expand_dims"; - - let arguments = (ins TT_Tensor:$src, I32Attr:$axis); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; -} - -def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "view"; - - let arguments = (ins TT_Tensor:$src); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; -} - -def 
TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "splat"; - - let arguments = (ins TT_Type:$src); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; - - let hasFolder = 1; -} - -def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "broadcast. No left-padding as of now."; - - let arguments = (ins TT_Type:$src); - - let results = (outs TT_Type:$result); - - let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; - - let hasFolder = 1; -} - -def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "concatenate 2 tensors"; - - let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; -} - -// -// builtin Ops -// -def TT_GetProgramIdOp : TT_Op<"get_program_id"> { - let arguments = (ins I32Attr:$axis); - - let results = (outs I32:$result); - - let assemblyFormat = "attr-dict `:` type($result)"; -} - -def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { - let arguments = (ins I32Attr:$axis); - - let results = (outs I32:$result); - - let assemblyFormat = "attr-dict `:` type($result)"; -} - -def TT_DotOp : TT_Op<"dot", [NoSideEffect, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { - let summary = "dot"; - - let description = [{ - $d = matrix_multiply($a, $b) + $c - }]; - - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); - - let results = (outs TT_FpIntTensor:$d); - - let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; - - // let hasCanonicalizer = 1; -} - -def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> { - let summary = "reduce"; - - let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); - - let results = (outs TT_Tensor:$result); - - let builders = [ - OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>, - ]; - - let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; -} - -def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { let summary = "atomic rmw"; let description = [{ @@ -271,7 +186,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { let results = (outs TT_Type:$result); } -def TT_AtomicCASOp : TT_Op<"atomic_cas"> { +def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { let summary = "atomic cas"; let description = [{ @@ -289,10 +205,133 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> { let results = (outs TT_Type:$result); } + // -// External Function Ops +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [NoSideEffect, + SameOperandsAndResultElementType]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` 
functional-type(operands, results)"; +} + +def TT_ViewOp : TT_Op<"view", [NoSideEffect, + SameOperandsAndResultElementType]> { + let summary = "view"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; + +} + +def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, + SameOperandsAndResultElementType]> { + let summary = "broadcast. No left-padding as of now."; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; + + let hasFolder = 1; +} + +def TT_CatOp : TT_Op<"cat", [NoSideEffect, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id"> { + let arguments = (ins I32Attr:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { + let arguments = (ins I32Attr:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [NoSideEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c + }]; + + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + + let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; +} + +// +// Reduce Op +// +def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect, + DeclareOpInterfaceMethods]> { + let summary = "reduce"; + + let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let builders = [ + OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>, + ]; + + let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; +} + +// +// External elementwise op // def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape, + SameOperandsAndResultEncoding, SameVariadicOperandSize]> { let summary = "ext_elemwise"; @@ -307,10 +346,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe let results = (outs TT_Tensor:$result); let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)"; + } // -// Intrinsics +// Make Range Op // // TODO: should have ConstantLike as Trait def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> { diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index cc62c0f3a..13b488052 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -11,6 +11,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #define GET_ATTRDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" #define GET_OP_CLASSES diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td 
b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 930b35dac..5964c7c37 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -2,6 +2,7 @@ #define TRITONGPU_ATTRDEFS include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" //===----------------------------------------------------------------------===// // TritonGPU Attribute Definitions @@ -34,6 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed. code extraBaseClassDeclaration = [{ unsigned getElemsPerThread(ArrayRef shape) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; }]; } @@ -301,15 +303,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { TODO: improve docs A = [x x x x x x x x] - [x x x x x x x x] - L_parent = [0 1 2 3 ] - [4 5 6 7 ] - [8 9 10 11] - [12 13 14 15] + + parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] dim = 0 Then the data of A would be distributed as follow between the 16 CUDA threads: - L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ] + L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes. diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index c03ce486c..493f9afd7 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -76,15 +76,4 @@ def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::Modu let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; } -def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> { - let summary = "verify TritonGPU IR"; - - let description = [{}]; - - let constructor = "mlir::createTritonGPUVerifier()"; - - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", - "mlir::gpu::GPUDialect"]; -} - #endif diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index de0c409e4..788d20eaa 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -185,9 +185,16 @@ struct TritonExpandDimsPattern // return type RankedTensorType retType = RankedTensorType::get(retShape, argType.getElementType(), retEncoding); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.axis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); // construct new op - rewriter.replaceOpWithNewOp( - op, retType, adaptor.src(), adaptor.axis()); + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.src()); + rewriter.replaceOpWithNewOp(op, newSrc, + adaptor.axis()); return success(); } }; @@ -310,9 +317,8 @@ struct TritonReducePattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); auto newOp = rewriter.replaceOpWithNewOp( - op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()); + op, adaptor.redOp(), adaptor.operand(), adaptor.axis()); 
return success(); } }; diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 2fa15a9b9..2d679b21f 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -1,7 +1,9 @@ add_mlir_dialect_library(TritonIR + Interfaces.cpp Dialect.cpp Ops.cpp Types.cpp + Traits.cpp DEPENDS TritonTableGen diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index 4b286e5b8..14fb30c21 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -1,6 +1,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" diff --git a/lib/Dialect/Triton/IR/Interfaces.cpp b/lib/Dialect/Triton/IR/Interfaces.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 4628f5d1e..9a47829c5 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -160,18 +160,88 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, } //-- DotOp -- +mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = operands[2].getType().cast(); + inferredReturnTypes.push_back(accTy); + return mlir::success(); +} + +//-- ReduceOp -- +mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + Value arg = operands[0]; + auto argTy = arg.getType().cast(); + auto retShape = argTy.getShape().vec(); + int axis = attributes.get("axis").cast().getInt(); + retShape.erase(retShape.begin() + axis); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return mlir::failure(); + } + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return mlir::success(); +} //-- SplatOp -- OpFoldResult SplatOp::fold(ArrayRef operands) { auto constOperand = src().getDefiningOp(); if (!constOperand) return {}; - auto shapedType = getType().cast(); auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()}); return ret; } +//-- ExpandDimsOp -- +mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = arg.getType().cast(); + auto retShape = argTy.getShape().vec(); + int axis = attributes.get("axis").cast().getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = 
argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ExpandDimsOp"); + return mlir::failure(); + } + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return mlir::success(); +} + //-- BroadcastOp -- OpFoldResult BroadcastOp::fold(ArrayRef operands) { auto constOperand = src().getDefiningOp(); diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..c3ce7e9d4 --- /dev/null +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,69 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) { + using namespace mlir; + auto encA = tyA.dyn_cast(); + auto encB = tyA.dyn_cast(); + if (!encA || !encB) + return success(); + return encA.getEncoding() == encB.getEncoding() ? success() : failure(); +} + +mlir::LogicalResult +mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + return verifySameOperandsEncoding(op); +} + +mlir::LogicalResult +mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = opType.dyn_cast()) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = opType.dyn_cast()) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule"; + } + } + return success(); +} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 4ee053b19..f9da54a64 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3,6 +3,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" +#include 
"triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" @@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { } } - return parser.getChecked( + auto ret = parser.getChecked( parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order); + return ret; } void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { @@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const { Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; - // Parse the data as a dictionary - DictionaryAttr dict; - if (parser.parseAttribute(dict).failed()) + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) return {}; if (parser.parseGreater().failed()) return {}; - - unsigned dim = 0; - Attribute parent; - - for (const NamedAttribute &attr : dict) { - if (attr.getName() == "dim") { - if (parseUInt(parser, attr, dim, "dim").failed()) - return {}; - } - if (attr.getName() == "parent") { - if (parser.parseAttribute(parent).failed()) - return {}; - } - } - + unsigned dim = attrs.get("dim").cast().getInt(); + Attribute parent = attrs.get("parent"); return parser.getChecked(parser.getContext(), dim, parent); } @@ -522,6 +510,35 @@ public: } }; +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis, + Attribute &resultEncoding) const { + resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis, + operandEncoding); + return success(); + } + + LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis, + Attribute &resultEncoding) const { + auto sliceEncoding = operandEncoding.dyn_cast(); + if (!sliceEncoding) { + llvm::report_fatal_error( + "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + return failure(); + } + if (sliceEncoding.getDim() != axis) { + llvm::report_fatal_error( + "Incompatible slice dimension for ExpandDimsOp operand"); + return failure(); + } + resultEncoding = sliceEncoding.getParent(); + return success(); + } +}; + void TritonGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST @@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() { #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" >(); addInterfaces(); + addInterfaces(); } //===----------------------------------------------------------------------===// @@ -568,4 +586,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // TODO: fill this. 
return success(); -} +} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 9f15374ef..464a1aabf 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms Combine.cpp Pipeline.cpp Swizzle.cpp - Verifier.cpp TritonGPUConversion.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp deleted file mode 100644 index 1acdf915c..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" - -#include - -using namespace mlir; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -class TritonGPUVerifier : public TritonGPUVerifierBase { -public: - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp m = getOperation(); - - // The idea is similar to mlir/lib/IR/Verifier.cpp - verifyImpl(m.getOperation()); - } - -private: - LogicalResult verifySingleOp(Operation *op) { - if (auto dotOp = llvm::dyn_cast(op)) { - Type aType = dotOp.a().getType(); - Type bType = dotOp.b().getType(); - Type cType = dotOp.c().getType(); - Type dType = dotOp.d().getType(); - for (auto it : llvm::zip(llvm::SmallVector{aType, bType}, - llvm::SmallVector{'a', 'b'})) { - Type type = std::get<0>(it); - char name = std::get<1>(it); - if (auto tensorType = type.dyn_cast()) { - Attribute encoding = tensorType.getEncoding(); - if (!encoding) - return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa()) - return dotOp.emitError() << name << " should be of shared layout"; - } else - return dotOp.emitError() - << name << "'s type should be of RankedTensorType"; - } - - Attribute cLayout; - for (auto it : llvm::zip(llvm::SmallVector{cType, dType}, - llvm::SmallVector{'c', 'd'})) { - Type type = std::get<0>(it); - char name = std::get<1>(it); - if (auto tensorType = type.dyn_cast()) { - Attribute encoding = tensorType.getEncoding(); - if (!encoding) - return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa() && - !encoding.isa()) - return dotOp.emitError() - << name << " should be of distributed layout"; - if (name == 'c') - cLayout = encoding; - else if (encoding != cLayout) - return dotOp.emitError() << "d & c should have the same layout"; - } else - return dotOp.emitError() - << name << "'s type should be of RankedTensorType"; - } - - // signalPassFailure(); - } - if (auto loadOp = llvm::dyn_cast(op)) { - // TODO: fill this - } - if (auto storeOp = llvm::dyn_cast(op)) { - // TODO: fill this - } - if (auto addptrOp = llvm::dyn_cast(op)) { - // TODO: fill this - } - // Triton builtin Ops - if (llvm::isa(op)) { - // TODO: fill this - } - if (auto atomicRmw = llvm::dyn_cast(op)) { - // TODO: fill this - } - if (auto atomicCas = llvm::dyn_cast(op)) { - // TODO: fill this - } - - // TODO: Arithmetic, SCF, TritonGPU ops - return success(); - } - - void verifyImpl(Operation *op) { - if (verifySingleOp(op).failed()) - signalPassFailure(); - - // verify that all child regions are ok - for (Region ®ion : op->getRegions()) - for (Block &block : region) - for (Operation &childOp : block) - verifyImpl(&childOp); - } -}; - -std::unique_ptr mlir::createTritonGPUVerifier() { - return std::make_unique(); -} diff --git 
a/python/src/triton.cc b/python/src/triton.cc index 6512546c8..e3947a3f8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1182,10 +1182,6 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCombineOpsPass()); }) - .def("add_triton_gpu_verifier_pass", - [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPUVerifier()); - }) .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 46ca7fd16..6a3ccdb97 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -861,7 +861,6 @@ def optimize_tritongpu_ir(mod, num_stages): pm.add_cse_pass() pm.add_coalesce_pass() pm.add_triton_gpu_combine_pass() - pm.add_triton_gpu_verifier_pass() pm.run(mod) return mod diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 19d403dbb..8b88f1787 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -1,6 +1,7 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> @@ -164,7 +165,7 @@ func @alloc(%A : !tt.ptr) { func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: scratch offset = 0, size = 512 - %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #AL> + %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> return // CHECK-NEXT: size = 512 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 1f7449c78..e4caf6294 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1,6 +1,7 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> @@ -69,7 +70,8 @@ func @scratch() { // CHECK: Membar 1 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> // CHECK-NEXT: Membar 3 - %b = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A> + %aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + %b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 2c8583fed..1c182e79a 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ 
-348,15 +348,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> +#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { - %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0> - %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1> - %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2> - %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3> + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3> %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2> %cst_scalar = arith.constant 64 : i32 %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2> @@ -387,15 +389,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> +#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { - %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0> - %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1> - %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2> - %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3> + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> 
tensor<1x32xi32, #block3> %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2> %cst_scalar = arith.constant 32 : i32 %cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2> @@ -429,15 +433,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> +#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { - %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0> - %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0> - %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2> - %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3> + %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3> %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2> %cst_scalar = arith.constant 32 : i32 %cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2> diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index f34f10003..e6d137e71 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -1,8 +1,10 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> +#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -23,13 +25,14 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> - %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> - %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1> + %00 = tt.make_range {end = 64 : 
i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1> %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> - %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> + %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 8a0aabf67..91ad2703b 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -53,7 +53,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> #blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> @@ -90,13 +92,14 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt // CHECK: return %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> %cst_0 = arith.constant dense : tensor<64x64xi1, #blocked1> - %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> - %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1> %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> - %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> + %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> @@ -138,13 +141,14 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index %cst_1 = arith.constant 
dense<0.000000e+00> : tensor<64x64xf32, #blocked1> - %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> - %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : (tensor<64xi32, #slice1dim1>) -> tensor<64x1xi32, #blocked1> %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1> %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1> - %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2> + %6 = tt.expand_dims %01 {axis = 0 : i32} : (tensor<64xi32, #slice2dim0>) -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 320e2116a..f24dd67f1 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s // 4 warps // matmul: 128x32 @ 32x128 -> 128x128
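
For reference, the new impl::verifyTensorSize in lib/Dialect/Triton/IR/Traits.cpp enforces two rules on every ranked-tensor operand and result: at most 1048576 elements, and a power-of-two element count. The standalone C++ sketch below is an illustration written against the patch, not code taken from the diff; the shapes in main() are made up for the demo and isolate just that arithmetic.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors maxTensorNumElements from Traits.h: 2^20 elements, chosen so a
// 1-byte-per-element tensor fits the ~256KB-per-SM register budget discussed
// in the patch comments.
constexpr int64_t maxTensorNumElements = 1048576;

// Returns true when a shape passes both checks performed by verifyTensorSize.
bool isValidTensorShape(const std::vector<int64_t> &shape) {
  int64_t numElements = 1;
  for (int64_t s : shape)
    numElements *= s;
  bool isPowerOfTwo = (numElements & (numElements - 1)) == 0;
  return numElements <= maxTensorNumElements && isPowerOfTwo;
}

int main() {
  std::cout << isValidTensorShape({128, 128}) << "\n";   // 1: 16384 elements, power of two
  std::cout << isValidTensorShape({1024, 1024}) << "\n"; // 1: exactly the 1048576 limit
  std::cout << isValidTensorShape({1024, 2048}) << "\n"; // 0: exceeds the limit
  std::cout << isValidTensorShape({24, 32}) << "\n";     // 0: 768 is not a power of two
  return 0;
}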
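
ReduceOp::inferReturnTypes and ExpandDimsOp::inferReturnTypes in lib/Dialect/Triton/IR/Ops.cpp build the result type in two steps: the shape drops or gains the axis dimension, and the encoding is delegated to DialectInferLayoutInterface. Below is a dependency-free sketch of the shape half only; the helper names are mine, not part of the patch.

#include <cstdint>
#include <vector>

// reduce removes the reduced dimension, e.g. tensor<16x16xf16> -> tensor<16xf16>
// for axis = 0, as in the updated test/Analysis/test-allocation.mlir.
std::vector<int64_t> reduceResultShape(std::vector<int64_t> shape, int axis) {
  shape.erase(shape.begin() + axis);
  return shape;
}

// expand_dims inserts a unit dimension, e.g. tensor<64xi32> -> tensor<64x1xi32>
// for axis = 1, as in the updated test/TritonGPU/coalesce.mlir.
std::vector<int64_t> expandDimsResultShape(std::vector<int64_t> shape, int axis) {
  shape.insert(shape.begin() + axis, 1);
  return shape;
}

int main() {
  auto reduced = reduceResultShape({16, 16}, 0);    // {16}
  auto expanded = expandDimsResultShape({64}, 1);   // {64, 1}
  return (reduced.size() == 1 && expanded.size() == 2) ? 0 : 1;
}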
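
The encoding half of that inference goes through a dialect interface, so TritonIR never has to know about TritonGPU layouts: Ops.cpp looks the interface up on the dialect that owns the encoding attribute, and TritonGPUInferLayoutInterface (registered in TritonGPUDialect::initialize) supplies the implementation. The fragment below sketches that lookup under the assumption that the Triton and MLIR headers from this patch are available; it is not a standalone program, and error handling is reduced to a plain failure instead of the report_fatal_error used in the real code.

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Ask the dialect that owns `operandEncoding` how a reduce along `axis`
// changes the layout; fails if that dialect does not implement the interface.
static LogicalResult inferReduceEncodingViaDialect(Attribute operandEncoding,
                                                   int axis,
                                                   Attribute &resultEncoding) {
  Dialect &dialect = operandEncoding.getDialect();
  auto iface = llvm::dyn_cast<triton::DialectInferLayoutInterface>(&dialect);
  if (!iface)
    return failure();
  return iface->inferReduceOpEncoding(operandEncoding, axis, resultEncoding);
}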
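
The two interface implementations are deliberately inverse to each other: inferReduceOpEncoding wraps the operand layout in slice<dim = axis, parent = layout>, while inferExpandDimsOpEncoding only accepts a slice whose dim equals the expanded axis and returns its parent. That is why the updated tests thread #triton_gpu.slice layouts between tt.make_range, tt.reduce, and tt.expand_dims. The toy model below illustrates the round-trip; the structs are stand-ins I wrote for the real attribute classes, not the MLIR API.

#include <cassert>
#include <memory>

struct Layout { virtual ~Layout() = default; }; // stand-in for Attribute
struct Blocked : Layout {};                     // stand-in for BlockedEncodingAttr
struct Slice : Layout {                         // stand-in for SliceEncodingAttr
  int dim;
  const Layout *parent;
  Slice(int d, const Layout *p) : dim(d), parent(p) {}
};

// Models inferReduceOpEncoding: reducing along `axis` yields slice<axis, parent>.
std::unique_ptr<Layout> inferReduceEncoding(const Layout *operand, int axis) {
  return std::make_unique<Slice>(axis, operand);
}

// Models inferExpandDimsOpEncoding: only a slice along the same axis can be
// unwrapped back to its parent (the real code reports a fatal error otherwise).
const Layout *inferExpandDimsEncoding(const Layout *operand, int axis) {
  auto *slice = dynamic_cast<const Slice *>(operand);
  if (!slice || slice->dim != axis)
    return nullptr;
  return slice->parent;
}

int main() {
  Blocked blocked;
  auto sliced = inferReduceEncoding(&blocked, /*axis=*/1);
  assert(inferExpandDimsEncoding(sliced.get(), 1) == &blocked); // round-trips
  assert(inferExpandDimsEncoding(sliced.get(), 0) == nullptr);  // axis mismatch
  return 0;
}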