diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index cabbdba90..e96dfbafe 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -10,7 +10,7 @@ jobs: Integration-Tests: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index c9be250fc..7dfc6a08f 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -18,16 +18,14 @@ namespace mlir { /// Axis information is represented by a std::map class AxisInfo { public: - typedef std::vector<int> ContiguityT; - typedef std::vector<int> DivisibilityT; - typedef std::vector<int> ConstancyT; + typedef SmallVector<int, 4> DimVectorT; public: // Default constructor AxisInfo() : AxisInfo({}, {}, {}) {} // Construct contiguity info with known contiguity - AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility, - ConstancyT knownConstancy) + AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility, + DimVectorT knownConstancy) : contiguity(knownContiguity), divisibility(knownDivisibility), constancy(knownConstancy), rank(contiguity.size()) { assert(knownDivisibility.size() == rank); @@ -36,13 +34,13 @@ public: // Accessors int getContiguity(size_t d) const { return contiguity[d]; } - const ContiguityT &getContiguity() const { return contiguity; } + const DimVectorT &getContiguity() const { return contiguity; } int getDivisibility(size_t d) const { return divisibility[d]; } - const DivisibilityT &getDivisibility() const { return divisibility; } + const DimVectorT &getDivisibility() const { return divisibility; } int getConstancy(size_t d) const { return constancy[d]; } - const ConstancyT &getConstancy() const { return constancy; } + const DimVectorT &getConstancy() const { return constancy; } int getRank() const { return rank; } @@ -78,7 +76,7 @@ private: /// [18, 22, 26, 30] /// [19, 23, 27, 31] /// Would have contiguity [2, 1].
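Contiguity, divisibility and constancy form a lattice: when two dataflow facts about the same value merge, AxisInfo::join (in lib/Analysis/AxisInfo.cpp further down in this diff) takes the per-dimension gcd, which is the strongest property implied by both. A minimal Python model of that join — a hypothetical illustration, not part of the patch:

```python
from math import gcd

# Each AxisInfo value is a per-dimension triple of facts; the lattice join
# of two facts about the same value is their per-dimension gcd.
def join(lhs, rhs):
    return {k: [gcd(a, b) for a, b in zip(lhs[k], rhs[k])]
            for k in ("contiguity", "divisibility", "constancy")}

a = {"contiguity": [128, 1], "divisibility": [16, 1], "constancy": [1, 4]}
b = {"contiguity": [64, 1],  "divisibility": [8, 1],  "constancy": [1, 2]}
assert join(a, b) == {"contiguity": [64, 1], "divisibility": [8, 1],
                      "constancy": [1, 2]}
```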
- ContiguityT contiguity; + DimVectorT contiguity; /// The _divisibility_ information maps the `d`-th /// dimension to the largest power-of-two that @@ -93,7 +91,7 @@ private: /// [14, 18, 22, 26] /// [15, 19, 23, 27] // would have divisibility [4, 1] - DivisibilityT divisibility; + DimVectorT divisibility; /// The _constancy_ information maps the `d`-th /// dimension to the length of the shortest @@ -104,7 +102,7 @@ private: /// [8, 8, 8, 8, 12, 12, 12, 12] /// [16, 16, 16, 16, 20, 20, 20, 20] /// would have constancy [1, 4] - ConstancyT constancy; + DimVectorT constancy; // number of dimensions of the lattice int rank; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d8e51ed06..ba65e11fe 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -133,8 +133,8 @@ def TT_GEPOp : TT_Op<"getelementptr", // // Shape Manipulation Ops // -def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "reshape"; +def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "view"; let arguments = (ins TT_Tensor:$src); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 0f0e570ae..f5f5d5904 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -2,43 +2,60 @@ #define TRITONGPU_ATTRDEFS include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" -// include "mlir/IR/TensorEncoding.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// class TritonGPU_Attr<string name, list<Trait> traits = [], string baseCppClass = "::mlir::Attribute"> - : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>; + : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +$\mathcal{L}$ that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers $T$ corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that: +- T[0,0] is owned by both CUDA threads 0 and 4 +- T[0,1] is owned by both CUDA threads 1 and 5 +- T[1,0] is owned by both CUDA threads 2 and 6 +- T[1,1] is owned by both CUDA threads 3 and 7 + +Right now, Triton implements two classes of layouts: shared and distributed. + }]; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> { - let mnemonic = "shared_layout"; + let mnemonic = "shared"; let description = [{ An encoding for tensors whose elements may be simultaneously accessed by -different warps in the programs, via shared memory. +different CUDA threads in the program, via shared memory. In other words, +for all indices $i \in \mathbb{Z}^d$, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
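The layout-function example above is easy to model mechanically — a hypothetical Python illustration, not part of the patch:

```python
# The example layout function written as a table: a tensor index maps to the
# set of CUDA threads that own the element stored at that index.
layout = {
    (0, 0): {0, 4},
    (0, 1): {1, 5},
    (1, 0): {2, 6},
    (1, 1): {3, 7},
}

def owners(i):
    return layout[i]

assert owners((0, 1)) == {1, 5}  # T[0,1] is owned by CUDA threads 1 and 5
```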
-In order to avoid shared memory bank conflicts, elements may be stored in a -swizzled layout. -For example, a swizzled row-major layout stores would store data as follows: +In order to avoid shared memory bank conflicts, elements may be swizzled +in memory. For example, a swizzled row-major layout could store its data +as follows: A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2 A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] / - groups of vec=2 elements are stored contiguously _ _ _ _ /\_ _ _ _ A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2 A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / - - -And the associated TritonGPU MLIR - -```mlir -#SMEM = #triton_gpu.shared_layout<{ - vec = 2, - perPhase = 2, - maxPhase = 4, - order = [1, 0] -}> -``` }]; let parameters = ( @@ -49,143 +66,222 @@ And the associated TritonGPU MLIR ); } +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// + +class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> { + let mnemonic = "distributed"; + + let description = [{ +Distributed encodings have a layout function that is entirely characterized +by a d-dimensional tensor L. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in \mathbb{Z}^d, as follows: + +\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such that i_d + k_d*A.shape[d] < L.shape[d] + +For example, for a tensor/layout pair +A = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of A would be distributed as follows between the 16 CUDA threads: +L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; +} + + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> { - let mnemonic = "blocked_layout"; + let mnemonic = "blocked"; let description = [{ An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout -consumed (and returned) by LoadInst. -For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 warps (i.e. 64 threads) as follows: +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the number of elements owned by each CUDA thread, warp and CTA respectively. - thread tile size 2 - - - - - - - /\ - - - - - - -warp | thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3] -tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3] -size } .... -32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63] - | A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ...
A_{15, 14}[T63] A_{15, 15}[T63] - -----------------------------/\----------------------------------- - warp tile size 8 +For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows. - - A_{32, 0}[T0] A_{32, 1}[T0] ... A_{32, 6}[T3] A_{32, 7}[T3] A_{32, 8}[T0] A_{32, 9}[T0] ... A_{32, 14}[T3] A_{32, 15}[T3] - A_{33, 0}[T0] A_{33, 1}[T0] ... A_{33, 6}[T3] A_{33, 7}[T3] A_{33, 8}[T0] A_{33, 9}[T0] ... A_{33, 14}[T3] A_{33, 15}[T3] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] - A_{62, 0}[T60] A_{62, 1}[T60] ... A_{62, 6}[T63] A_{62, 7}[T63] A_{62, 8}[T60] A_{62, 9}[T60] ... A_{62, 14}[T63] A_{62, 15}[T63] - A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63] +for -And the associated TritonGPU MLIR -#LAYOUT = #triton_gpu.blocked_layout<{ - threadTileSize = {2, 2} - blockTileSize = {32, 8} +#triton_gpu.blocked<{ + sizePerThread = [2, 2], + threadsPerWarp = [8, 4], + warpsPerCTA = [1, 2], + order = [1, 0] }> - -// note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize` - probably clearer to have easier semantics (i.e., size of each tile owned by a thread or a block) }]; + + let builders = [ + // Custom builder initializes sizePerWarp and sizePerCTA automatically + // TODO: compiles on MacOS but not linux? + // AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread, + // "ArrayRef<unsigned>":$threadsPerWarp, + // "ArrayRef<unsigned>":$warpsPerCTA, + // "ArrayRef<unsigned>":$order), [{ + // int rank = threadsPerWarp.size(); + // SmallVector<unsigned> sizePerWarp(rank); + // SmallVector<unsigned> sizePerCTA(rank); + // for (unsigned i = 0; i < rank; i++) { + // sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]); + // sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]); + // } + // return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA); + // }]>, + // Custom builder initializes sizePerWarp and sizePerCTA automatically + // Default builder takes sizePerThread, order and numWarps, and tries to + // pack numWarps*32 threads in the provided order for use in a type + // of the given shape. + AttrBuilder<(ins "ArrayRef<int64_t>":$shape, + "ArrayRef<unsigned>":$sizePerThread, + "ArrayRef<unsigned>":$order, + "unsigned":$numWarps), [{ + int rank = sizePerThread.size(); + int remainingWarps = numWarps; + int remainingLanes = 32; + SmallVector<unsigned> threadsPerWarp(rank); + SmallVector<unsigned> warpsPerCTA(rank); + for (int _dim = 0; _dim < rank; ++_dim) { + int dim = order[_dim]; + int maxNumThreads = int(shape[dim]) / sizePerThread[dim]; + warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads); + threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads); + remainingWarps /= warpsPerCTA[dim]; + remainingLanes /= threadsPerWarp[dim]; + } + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order); + + }]> + ]; + + let parameters = ( ins - // TODO: should we rename this as laneTileSize?
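Blocked layouts are an instance of the distributed encodings described above: the tensor L here is the 16x16 matrix of thread indices in the example. That example can be reproduced directly from the three tuples — a hypothetical Python sketch, not part of the patch, hard-coding the row-major order = [1, 0] used throughout this example:

```python
# Recover the owning thread of each element of the 16x16 example from
# sizePerThread/threadsPerWarp/warpsPerCTA (dim 1 varies fastest).
size_per_thread  = [2, 2]
threads_per_warp = [8, 4]
warps_per_cta    = [1, 2]

def owner(r, c):
    t0 = (r // size_per_thread[0]) % threads_per_warp[0]
    t1 = (c // size_per_thread[1]) % threads_per_warp[1]
    w0 = (r // (size_per_thread[0] * threads_per_warp[0])) % warps_per_cta[0]
    w1 = (c // (size_per_thread[1] * threads_per_warp[1])) % warps_per_cta[1]
    lane = t0 * threads_per_warp[1] + t1  # dim 1 varies fastest
    warp = w0 * warps_per_cta[1] + w1
    return warp * 32 + lane

assert [owner(0, c) for c in range(16)] == \
    [0, 0, 1, 1, 2, 2, 3, 3, 32, 32, 33, 33, 34, 34, 35, 35]
assert [owner(2, c) for c in range(8)] == [4, 4, 5, 5, 6, 6, 7, 7]
```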
- ArrayRefParameter< - "unsigned", - /*desc*/"size of a tile that is holded by a thread" - >:$threadTileSize, - ArrayRefParameter< - "unsigned", - "size of the a tile that is holded by a warp" - >:$warpTileSize, - ArrayRefParameter< - "unsigned", - "size of a tile that is holded by a thread block" - >:$blockTileSize, - // // TODO: It seems that we don't need this (because we can re-compute this) - // ArrayRefParameter<"unsigned">:$reptitions, + ArrayRefParameter<"unsigned">:$sizePerThread, + ArrayRefParameter<"unsigned">:$threadsPerWarp, + ArrayRefParameter<"unsigned">:$warpsPerCTA, // fastest-changing axis first ArrayRefParameter< "unsigned", "order of axes by the rate of changing" - >:$order, - ArrayRefParameter<"unsigned">:$broadcastAxis - // "AffineMap":$threadOrdering, - // "AffineMap":warpOrdering, - // "AffineMap":$blockOrdering, - + >:$order + // These attributes can be inferred from the rest + // ArrayRefParameter<"unsigned">:$sizePerWarp, + // ArrayRefParameter<"unsigned">:$sizePerCTA ); - // let genVerifyDecl = 1; } -def TritonGPUBlockedMulticastEncodingAttr - : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> { - let mnemonic = "blocked_multicast_layout"; - - let description = [{ - to be broadcasted to blocked_layout - }]; - - // This needs to be synced with BlockedEncoding - let parameters = ( - ins - ArrayRefParameter<"unsigned">:$threadTileSize, - ArrayRefParameter<"unsigned">:$warpTileSize, - ArrayRefParameter<"unsigned">:$blockTileSize, - ArrayRefParameter<"unsigned">:$order, - // unique to broadcasted layout - ArrayRefParameter<"unsigned">:$broadcastAxis - ); - - // let genVerifyDecl = 1; -} +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// +// TODO: MMAv1 and MMAv2 should be two instances of the same class def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { - let mnemonic = "mma_layout"; - - let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}]; - - let parameters = ( - ins - // only used by Volta mma.884 - ArrayRefParameter<"unsigned">:$fragmentPerWarp, - // aka shapeOfInstr (e.g., {16,8,16}) - ArrayRefParameter<"unsigned">:$shapePerWarp, - // TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout) - ArrayRefParameter<"unsigned">:$warpPerTile, - // TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout) - ArrayRefParameter<"unsigned">:$shapePerTile, - // TODO: should Distributed layout also - ArrayRefParameter<"unsigned">:$repetitions, - ArrayRefParameter<"unsigned">:$contigPerThread, - ArrayRefParameter<"unsigned">:$broadcastAxis - // "AffineMap":$warpOrdering, - // "AffineMap":$blockOrdering - ); - - // let genVerifyDecl = 1; -} - -def TritonGPUMmaMulticastEncodingAttr - : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> { - let mnemonic = "mma_multicast_layout"; + let mnemonic = "mma"; let description = [{ - To be broadcasted to mma. - }]; +An encoding for tensors that have been produced by tensor cores. +It is characterized by two parameters: +- A 'version' which specifies the generation of the tensor cores +whose output is being partitioned: 1 for first-gen tensor cores (Volta), +and 2 for second-gen tensor cores (Turing/Ampere). +- A `blockTileSize` to indicate how data should be +partitioned between warps.
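For the version = 2 matrices shown below, the mapping inside one 16x8 warp tile follows the PTX mma.16816 accumulator fragment: lane t owns the element pair at row t/4, columns 2*(t%4) and 2*(t%4)+1, and the pattern repeats every 8 rows. A hypothetical Python sketch of that mapping, not part of the patch:

```python
# PTX mma.16816 accumulator mapping inside one 16x8 warp tile: lane t owns
# (t//4, 2*(t%4)) and (t//4, 2*(t%4) + 1); rows repeat with period 8.
def owner_v2(row, col, warp=0):
    lane = (row % 8) * 4 + (col % 8) // 2
    return warp * 32 + lane

assert [owner_v2(0, c) for c in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
assert [owner_v2(1, c) for c in range(8)] == [4, 4, 5, 5, 6, 6, 7, 7]
assert owner_v2(8, 0) == owner_v2(0, 0)  # rows 8..15 mirror rows 0..7
assert owner_v2(0, 0, warp=1) == 32      # warp 1 starts at thread 32
```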
+ +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ] +[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ] +[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ] +[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ] +[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22] +[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23] +[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22] +[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23] +[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14] +[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15] +[ .............................................................. +[ .............................................................. +[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30] +[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31] + + warp 1 = warp 0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38] +[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39] +[ .............................................................. +[ .............................................................. +[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62] +[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63] + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 1 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 2 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ...............................
+[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; - // This needs to be synced with MmaEncoding let parameters = ( ins - ArrayRefParameter<"unsigned">:$fragmentPerWarp, - ArrayRefParameter<"unsigned">:$shapePerWarp, - ArrayRefParameter<"unsigned">:$warpPerTile, - ArrayRefParameter<"unsigned">:$shapePerTile, - ArrayRefParameter<"unsigned">:$repetitions, - ArrayRefParameter<"unsigned">:$contigPerThread, - // unique to broadcasted layout - ArrayRefParameter<"unsigned">:$broadcastAxis + "unsigned":$version, + ArrayRefParameter<"unsigned">:$warpsPerCTA ); - - // let genVerifyDecl = 1; } + #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 74a3413f0..2ad90bc48 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -7,16 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ - scf.for() { - %a = load %a_ptr; - %b = load %b_ptr; - - %d = dot %a, %b, %c; - } - - => - - ... + TODO }]; let constructor = "mlir::createTritonGPUPipelinePass()"; diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index 6cb59c327..6669d2d9a 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -13,11 +13,11 @@ namespace mlir { class TritonGPUTypeConverter : public TypeConverter { public: - TritonGPUTypeConverter(MLIRContext *context, int numThreads); + TritonGPUTypeConverter(MLIRContext *context, int numWarps); private: MLIRContext *context; - int numThreads; + int numWarps; }; class TritonGPUConversionTarget : public ConversionTarget { @@ -26,9 +26,6 @@ class TritonGPUConversionTarget : public ConversionTarget { public: explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); - - /// update layouts & insert ConvertLayoutOp if necessary - LogicalResult refineLayouts(ModuleOp mod, int numThreads); }; } // namespace mlir diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index ef926c190..a11f1a970 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -48,17 +48,17 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { divHint = attr.cast<IntegerAttr>().getValue().getZExtValue(); } } - ContiguityT contiguity(rank, 1); - DivisibilityT divisibility(rank, divHint); - ConstancyT constancy(rank, 1); + DimVectorT contiguity(rank, 1); + DimVectorT divisibility(rank, divHint); + DimVectorT constancy(rank, 1); return AxisInfo(contiguity, divisibility, constancy); } // The gcd of both arguments for each dimension AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { - ContiguityT retContiguity; - DivisibilityT retDivisibility; - ConstancyT retConstancy; + DimVectorT retContiguity; + DimVectorT retDivisibility; + DimVectorT retConstancy; for (size_t d = 0; d < lhs.getRank(); d++) { retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); retDivisibility.push_back( @@ -78,9 +78,9 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp( const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility, const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) { int rank = lhsInfo.getRank(); - AxisInfo::ContiguityT newContiguity; - AxisInfo::DivisibilityT newDivisibility; - AxisInfo::ConstancyT newConstancy; + AxisInfo::DimVectorT newContiguity; + AxisInfo::DimVectorT
newDivisibility; + AxisInfo::DimVectorT newConstancy; for (size_t d = 0; d < rank; d++) { newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); @@ -101,9 +101,9 @@ ChangeResult AxisInfoAnalysis::visitOperation( llvm::dyn_cast<triton::MakeRangeOp>(op)) { int start = make_range.start(); int end = make_range.end(); - AxisInfo::ContiguityT contiguity = {end - start}; - AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)}; - AxisInfo::ConstancyT constancy = {1}; + AxisInfo::DimVectorT contiguity = {end - start}; + AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)}; + AxisInfo::DimVectorT constancy = {1}; curr = AxisInfo(contiguity, divisibility, constancy); } // Constant @@ -119,9 +119,9 @@ ChangeResult AxisInfoAnalysis::visitOperation( auto value = splatAttr.getSplatValue<int>(); TensorType ty = splatAttr.getType().cast<TensorType>(); curr = AxisInfo( - AxisInfo::ContiguityT(ty.getRank(), 1), - AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)), - AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end())); + AxisInfo::DimVectorT(ty.getRank(), 1), + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end())); } } // Addition @@ -156,9 +156,9 @@ ChangeResult AxisInfoAnalysis::visitOperation( Type _retTy = *op->result_type_begin(); TensorType retTy = _retTy.cast<TensorType>(); AxisInfo opInfo = operands[0]->getValue(); - AxisInfo::ContiguityT contiguity; - AxisInfo::DivisibilityT divisibility; - AxisInfo::ConstancyT constancy; + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; for (size_t d = 0; d < retTy.getRank(); d++) { contiguity.push_back(1); divisibility.push_back(opInfo.getDivisibility(0)); @@ -167,8 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( curr = AxisInfo(contiguity, divisibility, constancy); } // Reshape - // TODO: Replace by `unsqueeze` - if (llvm::isa<triton::ReshapeOp>(op)) { + if (llvm::isa<triton::ViewOp>(op)) { Type _retTy = *op->result_type_begin(); Type _opTy = *op->operand_type_begin(); TensorType retTy = _retTy.cast<TensorType>(); @@ -176,9 +175,9 @@ ChangeResult AxisInfoAnalysis::visitOperation( ArrayRef<int64_t> retShape = retTy.getShape(); ArrayRef<int64_t> opShape = opTy.getShape(); AxisInfo opInfo = operands[0]->getValue(); - AxisInfo::ContiguityT contiguity; - AxisInfo::DivisibilityT divisibility; - AxisInfo::ConstancyT constancy; + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; bool is_skewed = false; size_t current = 0; for (size_t d = 0; d < retTy.getRank(); d++) { @@ -209,9 +208,9 @@ ChangeResult AxisInfoAnalysis::visitOperation( ArrayRef<int64_t> retShape = retTy.getShape(); ArrayRef<int64_t> opShape = opTy.getShape(); AxisInfo opInfo = operands[0]->getValue(); - AxisInfo::ContiguityT contiguity; - AxisInfo::DivisibilityT divisibility; - AxisInfo::ConstancyT constancy; + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; for (size_t d = 0; d < retTy.getRank(); d++) { contiguity.push_back(opShape[d] == 1 ?
1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index c927e766d..4147a6256 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> { } }; +struct TritonBroadcastPattern + : public OpConversionPattern<triton::BroadcastOp> { + using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = adaptor.src().getType().cast<RankedTensorType>(); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + auto opType = op.getType().cast<RankedTensorType>(); + Type retType = RankedTensorType::get(opType.getShape(), + opType.getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType, + adaptor.getOperands()); + return success(); + } +}; + struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> { using OpConversionPattern<triton::ReduceOp>::OpConversionPattern; @@ -234,12 +256,12 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> { void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add<TritonGenericPattern<triton::ReshapeOp>, - TritonGenericPattern<triton::SplatOp>, - TritonGenericPattern<triton::BroadcastOp>, - TritonGenericPattern<triton::GEPOp>, TritonReducePattern, - TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, - TritonStorePattern>(typeConverter, context); + patterns.add< // TODO: view should have custom pattern that views the layout + TritonGenericPattern<triton::ViewOp>, + TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern, + TritonGenericPattern<triton::GEPOp>, TritonReducePattern, + TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, + TritonStorePattern>(typeConverter, context); } // @@ -317,9 +339,8 @@ public: void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); - int numThreads = numWarps * 32; // type converter - TritonGPUTypeConverter typeConverter(context, numThreads); + TritonGPUTypeConverter typeConverter(context, numWarps); TritonGPUConversionTarget target(*context, typeConverter); // rewrite patterns RewritePatternSet patterns(context); @@ -335,8 +356,8 @@ public: // update layouts // broadcast src => multicast, dst => broadcasted - if (failed(target.refineLayouts(mod, numWarps))) - return signalPassFailure(); + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); } }; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1413ea8fb..eefd13633 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser, return success(); }; +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer ") << desc; + return failure(); + } + value = intAttr.getUInt(); + return success(); +}; + //===----------------------------------------------------------------------===// // Attribute methods
//===----------------------------------------------------------------------===// #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" -static Attribute parseBlocked(AsmParser &parser, Type type) { +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary DictionaryAttr dict; if (parser.parseAttribute(dict).failed()) return {}; if (parser.parseGreater().failed()) return {}; - SmallVector<unsigned> threadTileSize; - SmallVector<unsigned> warpTileSize; - SmallVector<unsigned> blockTileSize; + SmallVector<unsigned> sizePerThread; + SmallVector<unsigned> threadsPerWarp; + SmallVector<unsigned> warpsPerCTA; SmallVector<unsigned> order; - SmallVector<unsigned> broadcastAxis; for (const NamedAttribute &attr : dict) { - if (attr.getName() == "threadTileSize") { - if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size") + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") .failed()) return {}; - } else if (attr.getName() == "warpTileSize") { - if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size") + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") .failed()) return {}; - } else if (attr.getName() == "blockTileSize") { - if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size") + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") .failed()) return {}; } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; - } else if (attr.getName() == "broadcastAxis") { - if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis") - .failed()) - return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); @@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) { } return parser.getChecked<TritonGPUBlockedEncodingAttr>( - parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order, - broadcastAxis); -} - -template <class T> -static void printBlocked(AsmPrinter &printer, const T *attr) { - printer << "<{" - << "threadTileSize = [" << attr->getThreadTileSize() << "]" - << ", warpTileSize = [" << attr->getWarpTileSize() << "]" - << ", blockTileSize = [" << attr->getBlockTileSize() << "]" - << ", order = [" << attr->getOrder() << "]" - << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]" - << "}>"; -} - -Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { - return parseBlocked(parser, type); + parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order); } void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { - printBlocked(printer, this); + printer << "<{" + << "sizePerThread = [" << getSizePerThread() << "]" + << ", threadsPerWarp = [" << getThreadsPerWarp() << "]" + << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" + << ", order = [" << getOrder() << "]" + << "}>"; } -Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, - Type type) { - return parseBlocked(parser, type); -} +//===----------------------------------------------------------------------===// +// MMA encoding
+//===----------------------------------------------------------------------===// -void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const { - printBlocked(printer, this); -} - -static Attribute parseMma(AsmParser &parser, Type type) { +Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; DictionaryAttr dict; @@ -121,76 +118,34 @@ ... if (parser.parseGreater().failed()) return {}; - SmallVector<unsigned> fragmentPerWarp; - SmallVector<unsigned> shapePerWarp; - SmallVector<unsigned> warpPerTile; - SmallVector<unsigned> shapePerTile; - SmallVector<unsigned> repetitions; - SmallVector<unsigned> contigPerThread; - SmallVector<unsigned> broadcastAxis; + unsigned version = 0; + SmallVector<unsigned> warpsPerCTA; for (const NamedAttribute &attr : dict) { - if (attr.getName() == "fragmentPerWarp") { - if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp") - .failed()) + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) return {}; - } else if (attr.getName() == "shapePerWarp") { - if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp") - .failed()) + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) return {}; - } else if (attr.getName() == "warpPerTile") { - if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed()) - return {}; - } else if (attr.getName() == "shapePerTile") { - if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile") - .failed()) - return {}; - } else if (attr.getName() == "repetitions") { - if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed()) - return {}; - } else if (attr.getName() == "contigPerThread") { - if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread") - .failed()) - return {}; - } else { - parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); - return {}; } } - return parser.getChecked<TritonGPUMmaEncodingAttr>( - parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile, - shapePerTile, repetitions, contigPerThread, broadcastAxis); -} - -template <class T> static void printMma(AsmPrinter &printer, T *attr) { - printer << "<{" - << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]" - << ", shapePerWarp = [" << attr->getShapePerWarp() << "]" - << ", warpPerTile = [" << attr->getWarpPerTile() << "]" - << ", shapePerTile = [" << attr->getShapePerTile() << "]" - << ", repetitions = [" << attr->getRepetitions() << "]" - << ", contigPerThread = [" << attr->getContigPerThread() << "]" - << "}>"; -} - -Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { - return parseMma(parser, type); + return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(), + version, warpsPerCTA); } void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { - printMma(printer, this); + printer << "<{" + << "version = " << getVersion() << ", " + << "warpsPerCTA = [" << getWarpsPerCTA() << "]" + << "}>"; } -Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, - Type type) { - return parseMma(parser, type); -} - -void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const { - printMma(printer, this); -} +//===----------------------------------------------------------------------===// +// Shared encoding +//===----------------------------------------------------------------------===// Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed())
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { unsigned maxPhase = 0; SmallVector<unsigned> order; - auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value, - StringRef desc) -> LogicalResult { - auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); - if (!intAttr) { - parser.emitError(parser.getNameLoc(), "expected an integer ") << desc; - return failure(); - } - value = intAttr.getUInt(); - return success(); - }; - for (const NamedAttribute &attr : dict) { if (attr.getName() == "vec") { - if (parseUInt(attr, vec, "vec").failed()) + if (parseUInt(parser, attr, vec, "vec").failed()) return {}; } else if (attr.getName() == "perPhase") { - if (parseUInt(attr, perPhase, "perPhase").failed()) + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) return {}; } else if (attr.getName() == "maxPhase") { - if (parseUInt(attr, maxPhase, "maxPhase").failed()) + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) return {}; } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) @@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { << "}>"; } +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + class TritonGPUOpAsmInterface : public OpAsmDialectInterface { public: using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) { os << "mma"; - TritonGPUOpAsmInterface::printMma(mmaAttr, os); - return AliasResult::FinalAlias; - } else if (auto mmaMulticastAttr = - attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) { - os << "mma_multicast"; - TritonGPUOpAsmInterface::printMma(mmaAttr, os); return AliasResult::FinalAlias; } else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) { os << "shared"; - TritonGPUOpAsmInterface::printShared(sharedAttr, os); return AliasResult::FinalAlias; } else if (auto blockedAttr = attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) { os << "blocked"; - TritonGPUOpAsmInterface::printBlocked(blockedAttr, os); - return AliasResult::FinalAlias; - } else if (auto blockedMulticastAttr = - attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) { - os << "blocked_multicast"; - TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os); return AliasResult::FinalAlias; } OpAsmDialectInterface::getAlias(attr, os); return AliasResult::FinalAlias; } -private: - static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) { - TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os); - TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os); - TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os); - TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os); - TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os); - TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os); - } - - static void printShared(const TritonGPUSharedEncodingAttr &attr, - raw_ostream &os) { - os << "_" << attr.getVec(); - os << "_" << attr.getPerPhase(); - os << "_" << attr.getMaxPhase(); - TritonGPUOpAsmInterface::printArray(attr.getOrder(), os); - } - - template <class T> static void printBlocked(const T &attr, raw_ostream &os) { - TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os); - TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os); - TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os); -
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os); - TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os); - } - - static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os, - const std::string &delimiter = "x") { - os << "_"; - if (array.empty()) { - os << "none"; - return; - } - for (unsigned i = 0; i < array.size(); i++) { - os << array[i]; - if (i != array.size() - 1) { - os << delimiter; - } - } - } }; void TritonGPUDialect::initialize() { diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 091ca05d3..521fa8964 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -3,6 +3,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include <algorithm> +#include <numeric> using namespace mlir; using namespace mlir::triton::gpu; @@ -11,54 +12,26 @@ using namespace mlir::triton::gpu; // TypeConverter // TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, - int numThreads) - : context(context), numThreads(numThreads) { + int numWarps) + : context(context), numWarps(numWarps) { // TODO: how does MLIR pick the right conversion? addConversion([](Type type) { return type; }); addConversion([this](RankedTensorType tensorType) -> RankedTensorType { - MLIRContext *context = this->context; - int numThreads = this->numThreads; - - llvm::ArrayRef<int64_t> shape = tensorType.getShape(); - Type elementType = tensorType.getElementType(); - int64_t rank = tensorType.getRank(); - int64_t numElements = tensorType.getNumElements(); - - // TODO: are there any better ways to raise this error? - if (!(numElements >= numThreads)) { - SmallVector<char> buffer; - llvm::raw_svector_ostream os(buffer); - os << tensorType << " has " << numElements << " numElements " - << " smaller than numThreads (" << numThreads << ")\n" - << "consider using smaller num-warps\n"; - llvm::report_fatal_error(os.str()); - } - assert(numElements % numThreads == 0); - - // or assert no encoding? - - // Now we assume: - // contiguous = 1, order = 0, 1, 2, ..., - llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout - llvm::SmallVector<unsigned> warpTileSize(rank, 1); - llvm::SmallVector<unsigned> blockTileSize(rank); + // types with encoding are already in the right format + // TODO: check for layout encodings specifically + if (tensorType.getEncoding()) + return tensorType; + // pessimistic values for attributes: + // - 1 element per thread + // - order = arange(rank) + ArrayRef<int64_t> shape = tensorType.getShape(); + int rank = shape.size(); llvm::SmallVector<unsigned> order(rank); - llvm::SmallVector<unsigned> broadcastAxis; - int remainingThreads = numThreads; - int remainingLanes = /*warp size*/ 32; - for (int64_t dim = 0; dim < rank; ++dim) { - blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim])); - warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim])); - order[dim] = dim; - - remainingThreads /= blockTileSize[dim]; - remainingLanes /= warpTileSize[dim]; - // TODO: will we need repetition?
- } + std::iota(order.begin(), order.end(), 0); + llvm::SmallVector<unsigned> sizePerThread(rank, 1); Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( - context, threadTileSize, warpTileSize, blockTileSize, order, - broadcastAxis); - return RankedTensorType::get(shape, elementType, encoding); + this->context, shape, sizePerThread, order, this->numWarps); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); }); // @@ -86,8 +59,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // NOTE: only for remapped values. addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { - llvm_unreachable("Not implemented"); - return llvm::None; + auto cast = + builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs); + return Optional<Value>(cast.getResult()); + // return Optional<Value>(cast.getResult(0)); + // llvm_unreachable("Not implemented"); + // return llvm::None; }); } @@ -122,87 +99,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() && bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) return true; - // // TODO: we should delete this - // if (this->typeConverter.isLegal(dotOp)) - // return true; return false; }); -} - -// %dst = tt.broadcast %src -// => -// %newSrc = convert_layout %src -// %bcst = tt.broadcast %newSrc -// %dst = convert_layout %bcst -LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod, - int numThreads) { - // collect broadcasts - SmallVector<triton::BroadcastOp> broadcasts; - mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); }); - - BlockAndValueMapping mapping; - for (auto broadcast : broadcasts) { - OpBuilder builder(broadcast); - Value src = mapping.lookupOrDefault(broadcast.src()); - Type originSrcType = src.getType(); - Type originDstType = broadcast.getType(); - auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>(); - unsigned dstRank = originDstTensorType.getRank(); - - // compute newSrcType & broadcastAxis - Type newSrcType; - SmallVector<unsigned> broadcastAxis; - bool isSrcScalar = false; - if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) { - assert(tensorType.getRank() == dstRank && - "src & dst should have same rank (verifier should catch this)"); - for (unsigned ax = 0; ax < dstRank; ++ax) - if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax]) - broadcastAxis.push_back(ax); - - Attribute originSrcEnc = tensorType.getEncoding(); - if (auto blockedEnc = - originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) { - auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get( - blockedEnc.getContext(), blockedEnc.getThreadTileSize(), - blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(), - blockedEnc.getOrder(), broadcastAxis); - newSrcType = RankedTensorType::get( - tensorType.getShape(), tensorType.getElementType(), newSrcEnc); - } else - llvm_unreachable("src of broadcast should have blocked encoding"); - } else { - for (unsigned ax = 0; ax < dstRank; ++ax) - broadcastAxis.push_back(ax); - newSrcType = originSrcType; - isSrcScalar = true; - } - - // create new src - if (!isSrcScalar) // we don't need to convert layout for scalar values - src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), - newSrcType, src); - - // create new broadcast - // compute new type (encoding) - auto originDstEnc = originDstTensorType.getEncoding() - .dyn_cast<TritonGPUBlockedEncodingAttr>(); - auto newEnc = TritonGPUBlockedEncodingAttr::get( - originDstEnc.getContext(), originDstEnc.getThreadTileSize(), - originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(), - originDstEnc.getOrder(), broadcastAxis); - auto newType = -
RankedTensorType::get(originDstTensorType.getShape(), - originDstTensorType.getElementType(), newEnc); - Value newBroadcast = - builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src); - // we don't want to change the encoding of the result - Value newDst = builder.create<triton::gpu::ConvertLayoutOp>( - broadcast.getLoc(), originDstType, newBroadcast); - - broadcast.replaceAllUsesWith(newDst); - mapping.map(broadcast, newDst); - } - - return success(); -} +} \ No newline at end of file diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py index 48e9dbb4e..7f95e8f24 100644 --- a/python/examples/copy_strided.py +++ b/python/examples/copy_strided.py @@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) - Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn - Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1 + Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn tl.store(Zs, tl.load(Xs)) -ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir") +ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir") print(ret) diff --git a/python/src/triton.cc b/python/src/triton.cc index b2ecbdd6b..9a7dbbd29 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) { self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask); }) // Block instruction - .def("create_reshape", + .def("create_view", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value { auto loc = self.getUnknownLoc(); auto argType = arg.getType() .dyn_cast<mlir::RankedTensorType>() .getElementType(); - return self.create<mlir::triton::ReshapeOp>( + return self.create<mlir::triton::ViewOp>( loc, mlir::RankedTensorType::get(shape, argType), arg); }) .def("create_cat", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 282f06db4..477633d2c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -565,7 +565,7 @@ class tensor: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = semantic.reshape(self, dst_shape, _builder) + ret = semantic.view(self, dst_shape, _builder) return ret @builtin diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 0f026b717..a3d0d3385 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -451,16 +451,16 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # ===----------------------------------------------------------------------===// -def reshape(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: +def view(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> tl.tensor: numel = 1 for s in dst_shape: numel *= s if input.type.numel != numel: raise ValueError("cannot reshape block of different shape") ret_ty = tl.block_type(input.type.scalar, dst_shape) - return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) + return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty) def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 3d2b31d58..041b88db2 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -10,7
+10,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1] %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1] - %2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32> + %2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1] @@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] - %7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32> + %7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1] @@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1] - %11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32> + %11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>> // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>> // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1] - %14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32> + %14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1] diff --git a/test/TritonGPU/layout.mlir b/test/TritonGPU/layout.mlir deleted file mode 100644 index e03c018bf..000000000 --- a/test/TritonGPU/layout.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: triton-opt %s -split-input-file -verify-diagnostics - -#reg = #triton_gpu.blocked_layout<{ - threadTileSize = [1, 1], - warpTileSize = [32, 1], - blockTileSize = [64, 1], - order = [0] -}> - -#reg2 = #triton_gpu.blocked_layout<{ - threadTileSize = [2, 1], - warpTileSize = [64, 1], - blockTileSize = [128, 1], - order = [0] -}> - -func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { - %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg> - return -} - -func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}} - // expected-error @+1 {{use of value '%arg0' expects different type than prior uses}} - %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2> - return -} \ No newline at end of file diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index f1697fac6..345e5d971 100644 ---
a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,45 +1,12 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier // RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked_layout<{ - threadTileSize = [1, 4], - warpTileSize = [4, 32], - blockTileSize = [16, 32], - order = [1, 0] -}> - -#BL = #triton_gpu.blocked_layout<{ - threadTileSize = [1, 4], - warpTileSize = [1, 128], - blockTileSize = [4, 128], - order = [1, 0] -}> - -#A = #triton_gpu.shared_layout<{ - vec = 2, - perPhase = 2, - maxPhase = 4, - order = [1, 0] -}> - -#B = #triton_gpu.shared_layout<{ - vec = 2, - perPhase = 2, - maxPhase = 4, - order = [1, 0] -}> - -// TODO: check this -#C = #triton_gpu.mma_layout<{ - fragmentPerWarp = [1, 1], - shapePerWarp = [16, 8], - warpPerTile = [2, 2], - shapePerTile = [32, 16], - repetitions = [4, 4], - contigPerThread = [1, 8] -}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> // CHECK: func @matmul_loop // CHECK: %[[A0:.*]] = triton_gpu.copy_async
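As a closing sanity check on the rewritten encodings in this test: under a blocked encoding, one CTA covers a tile of sizePerThread * threadsPerWarp * warpsPerCTA elements per dimension, and larger tensors are covered by wrap-around repetitions. A hypothetical Python check, not part of the patch:

```python
# The tile covered by one CTA under a blocked encoding is the elementwise
# product of the three tuples; larger tensors wrap around, so each thread
# ends up owning several such tiles.
def cta_tile(size_per_thread, threads_per_warp, warps_per_cta):
    return [s * t * w for s, t, w in
            zip(size_per_thread, threads_per_warp, warps_per_cta)]

# #AL from the test above: 4 warps = 128 threads over a 16x32 tile,
# so the 128x32 A operand is covered by 8 repetitions along dim 0.
assert cta_tile([1, 4], [4, 8], [4, 1]) == [16, 32]
# #BL: a 4x128 tile, covering the 32x128 B operand in 8 repetitions.
assert cta_tile([1, 4], [1, 32], [4, 1]) == [4, 128]
```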