[TritonGPU] Improved documentation and semantics of layout encodings (#30)

Philippe Tillet
2022-07-31 13:59:44 -07:00
committed by GitHub
parent e02c82c765
commit d1593e6ca8
17 changed files with 399 additions and 566 deletions

View File

@@ -10,7 +10,7 @@ jobs:
   Integration-Tests:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     steps:

View File

@@ -18,16 +18,14 @@ namespace mlir {
 /// Axis information is represented by a std::map<int, int>
 class AxisInfo {
 public:
-  typedef std::vector<int> ContiguityT;
-  typedef std::vector<int> DivisibilityT;
-  typedef std::vector<int> ConstancyT;
+  typedef SmallVector<int, 4> DimVectorT;

 public:
   // Default constructor
   AxisInfo() : AxisInfo({}, {}, {}) {}
   // Construct contiguity info with known contiguity
-  AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
-           ConstancyT knownConstancy)
+  AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
+           DimVectorT knownConstancy)
       : contiguity(knownContiguity), divisibility(knownDivisibility),
         constancy(knownConstancy), rank(contiguity.size()) {
     assert(knownDivisibility.size() == rank);
@@ -36,13 +34,13 @@ public:
   // Accessors
   int getContiguity(size_t d) const { return contiguity[d]; }
-  const ContiguityT &getContiguity() const { return contiguity; }
+  const DimVectorT &getContiguity() const { return contiguity; }
   int getDivisibility(size_t d) const { return divisibility[d]; }
-  const DivisibilityT &getDivisibility() const { return divisibility; }
+  const DimVectorT &getDivisibility() const { return divisibility; }
   int getConstancy(size_t d) const { return constancy[d]; }
-  const ConstancyT &getConstancy() const { return constancy; }
+  const DimVectorT &getConstancy() const { return constancy; }
   int getRank() const { return rank; }
@@ -78,7 +76,7 @@ private:
   /// [18, 22, 26, 30]
   /// [19, 23, 27, 31]
   /// Would have contiguity [2, 1].
-  ContiguityT contiguity;
+  DimVectorT contiguity;
   /// The _divisibility_ information maps the `d`-th
   /// dimension to the largest power-of-two that
@@ -93,7 +91,7 @@ private:
   /// [14, 18, 22, 26]
   /// [15, 19, 23, 27]
   //  would have divisibility [4, 1]
-  DivisibilityT divisibility;
+  DimVectorT divisibility;
   /// The _constancy_ information maps the `d`-th
   /// dimension to the length of the shortest
@@ -104,7 +102,7 @@ private:
   /// [8, 8, 8, 8, 12, 12, 12, 12]
   /// [16, 16, 16, 16, 20, 20, 20, 20]
   /// would have constancy [1, 4]
-  ConstancyT constancy;
+  DimVectorT constancy;
   // number of dimensions of the lattice
   int rank;
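For intuition, the three lattice values can be computed for a 1-D sequence in a few lines of Python. This is an illustrative sketch, not code from this commit; the helper names are invented:

```python
from math import gcd

def shortest_run(xs, step_ok):
    # length of the shortest maximal run where step_ok(prev, cur) holds
    runs, n = [], 1
    for prev, cur in zip(xs, xs[1:]):
        if step_ok(prev, cur):
            n += 1
        else:
            runs.append(n)
            n = 1
    runs.append(n)
    return min(runs)

def contiguity(xs):
    # shortest run of consecutive integers along the axis
    return shortest_run(xs, lambda a, b: b == a + 1)

def constancy(xs):
    # shortest run of equal values along the axis
    return shortest_run(xs, lambda a, b: b == a)

def divisibility(xs):
    # largest power of two dividing the first element of every contiguous run
    starts = [xs[0]] + [b for a, b in zip(xs, xs[1:]) if b != a + 1]
    g = 0
    for s in starts:
        g = gcd(g, s)
    return g & -g if g else 1 << 31  # 0 is "infinitely" divisible

assert contiguity([18, 19, 20, 21]) == 4       # one run of 4 consecutive ints
assert divisibility([12, 16, 20, 24]) == 4     # run starts share the factor 4
assert constancy([8, 8, 8, 8, 12, 12, 12, 12]) == 4
```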

View File

@@ -133,8 +133,8 @@ def TT_GEPOp : TT_Op<"getelementptr",
// //
// Shape Manipulation Ops // Shape Manipulation Ops
// //
def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]> { def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "reshape"; let summary = "view";
let arguments = (ins TT_Tensor:$src); let arguments = (ins TT_Tensor:$src);

View File

@@ -2,43 +2,60 @@
 #define TRITONGPU_ATTRDEFS

 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
-// include "mlir/IR/TensorEncoding.td"
+
+//===----------------------------------------------------------------------===//
+// TritonGPU Attribute Definitions
+//===----------------------------------------------------------------------===//

 class TritonGPU_Attr<string name, list<Trait> traits = [],
                      string baseCppClass = "::mlir::Attribute">
-  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
+  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
+  let description = [{
+TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute, which determines
+how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
+$\mathcal{L}$ that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers $T$ corresponding
+to the indices of the CUDA threads allowed to access some data at index $i$.
+
+For example, let us consider the layout function:
+
+$\mathcal{L}(0, 0) = \{0, 4\}$
+$\mathcal{L}(0, 1) = \{1, 5\}$
+$\mathcal{L}(1, 0) = \{2, 6\}$
+$\mathcal{L}(1, 1) = \{3, 7\}$
+
+Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
+- T[0,0] is owned by both CUDA threads 0 and 4
+- T[0,1] is owned by both CUDA threads 1 and 5
+- T[1,0] is owned by both CUDA threads 2 and 6
+- T[1,1] is owned by both CUDA threads 3 and 7
+
+Right now, Triton implements two classes of layouts: shared and distributed.
+  }];
+}
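The example layout above is small enough to write down explicitly. A quick illustrative sketch (not part of this commit):

```python
# The example layout function above as an explicit index -> thread-set map.
layout = {
    (0, 0): {0, 4},
    (0, 1): {1, 5},
    (1, 0): {2, 6},
    (1, 1): {3, 7},
}

# each element is owned by exactly two threads, and threads 0-7 are all used
assert all(len(owners) == 2 for owners in layout.values())
assert set().union(*layout.values()) == set(range(8))
```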
+//===----------------------------------------------------------------------===//
+// Shared Layout Encoding
+//===----------------------------------------------------------------------===//
+
 def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
-  let mnemonic = "shared_layout";
+  let mnemonic = "shared";

   let description = [{
 An encoding for tensors whose elements may be simultaneously accessed by
-different warps in the programs, via shared memory.
+different CUDA threads in the program, via shared memory. In other words,
+for all indices $i \in R^d$, $\mathcal{L}(i) = \{0, 1, ..., 32 \cdot num\_warps - 1\}$.

-In order to avoid shared memory bank conflicts, elements may be stored in a
-swizzled layout.
-For example, a swizzled row-major layout would store data as follows:
+In order to avoid shared memory bank conflicts, elements may be swizzled
+in memory. For example, a swizzled row-major layout could store its data
+as follows:

 A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...   [phase 0] \ per_phase = 2
 A_{1, 0}  A_{1, 1}  A_{1, 2}  A_{1, 3} ...   [phase 0] /

 groups of vec=2 elements
 are stored contiguously
 _ _ _ _ /\_ _ _ _
 A_{2, 2}  A_{2, 3}  A_{2, 0}  A_{2, 1} ...   [phase 1] \ per_phase = 2
 A_{3, 2}  A_{3, 3}  A_{3, 0}  A_{3, 1} ...   [phase 1] /

-And the associated TritonGPU MLIR
-```mlir
-#SMEM = #triton_gpu.shared_layout<{
-  vec = 2,
-  perPhase = 2,
-  maxPhase = 4,
-  order = [1, 0]
-}>
-```
   }];

   let parameters = (
@@ -49,143 +66,222 @@ And the associated TritonGPU MLIR
   );
 }
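The diagram can be read as an XOR permutation. Below is an assumption-based Python sketch of that reading (not code from this commit): groups of `vec` contiguous elements move together, and the permutation (the "phase") advances every `perPhase` rows, cycling after `maxPhase` phases.

```python
# Where element (row, col) of a row-major tile lands within its row in
# shared memory, under the swizzling illustrated above.
def swizzled_col(row, col, vec=2, per_phase=2, max_phase=4):
    phase = (row // per_phase) % max_phase
    group = col // vec  # groups of `vec` elements move together
    return (group ^ phase) * vec + col % vec

# rows 0-1 are phase 0 (identity); rows 2-3 are phase 1 (groups swapped),
# matching the  A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1}  row of the diagram
assert [swizzled_col(0, c) for c in range(4)] == [0, 1, 2, 3]
assert [swizzled_col(2, c) for c in range(4)] == [2, 3, 0, 1]
```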
+//===----------------------------------------------------------------------===//
+// Distributed Layout Encoding
+//===----------------------------------------------------------------------===//
+
+class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
+  let mnemonic = "distributed";
+
+  let description = [{
+Distributed encodings have a layout function that is entirely characterized
+by a d-dimensional tensor L. Note that L doesn't need to have the same shape
+(or even the same rank) as the tensor it is encoding.
+
+The layout function $\mathcal{L}$ of this layout is then defined, for an
+index $i \in R^d$, as follows:
+
+$\mathcal{L}(A)[i_d] = L[(i_d + k_d \cdot A.shape[d]) \bmod L.shape[d]] \quad \forall k_d \text{ such that } i_d + k_d \cdot A.shape[d] < L.shape[d]$
+
+For example, for a tensor/layout pair
+A = [x  x  x  x  x  x  x  x]
+    [x  x  x  x  x  x  x  x]
+L = [0  1  2  3 ]
+    [4  5  6  7 ]
+    [8  9  10 11]
+    [12 13 14 15]
+
+Then the data of A would be distributed as follows between the 16 CUDA threads:
+L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
+         {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
+  }];
+}
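The example table can be reproduced mechanically from the formula. Here is a minimal 2-D Python sketch (not from this commit; the helper name and looping strategy are illustrative): the `k_d` loop replicates data when the lattice L is larger than the tensor, and the modulo wraps indices when the tensor is larger than L.

```python
import itertools

def owning_threads(L, L_shape, A_shape, i):
    """Set of thread ids in lattice L that own element i of tensor A."""
    axes = []
    for d in range(2):
        reps = max(1, L_shape[d] // A_shape[d])
        axes.append(sorted({(i[d] + k * A_shape[d]) % L_shape[d]
                            for k in range(reps)}))
    return {L[j0][j1] for j0, j1 in itertools.product(*axes)}

L = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]

# the 2x8 tensor A of the example, distributed over the 4x4 lattice L
assert owning_threads(L, (4, 4), (2, 8), (0, 0)) == {0, 8}
assert owning_threads(L, (4, 4), (2, 8), (1, 5)) == {5, 13}
```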
+//===----------------------------------------------------------------------===//
+// Blocked Layout Encoding
+//===----------------------------------------------------------------------===//
+
 def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
-  let mnemonic = "blocked_layout";
+  let mnemonic = "blocked";

   let description = [{
 An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
-consumed (and returned) by LoadInst.
-For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 warps (i.e. 64 threads) as follows:
-
-                      thread tile size 2
-                - - - - - - /\ - - - - - -
- warp | thread      ||  A_{0, 0}[T0]  A_{0, 1}[T0]  ...  A_{0, 6}[T3]  A_{0, 7}[T3]   A_{0, 8}[T0]  A_{0, 9}[T0]  ...  A_{0, 14}[T3]  A_{0, 15}[T3]
- tile | tile size 2 ||  A_{1, 0}[T0]  A_{1, 1}[T0]  ...  A_{1, 6}[T3]  A_{1, 7}[T3]   A_{1, 8}[T0]  A_{1, 9}[T0]  ...  A_{1, 14}[T3]  A_{1, 15}[T3]
- size }  ....
-  32  |  A_{30, 0}[T60]  A_{14, 1}[T60]  ...  A_{14, 6}[T63]  A_{14, 7}[T63]   A_{14, 8}[T60]  A_{14, 9}[T60]  ...  A_{14, 14}[T63]  A_{14, 15}[T63]
-      |  A_{31, 0}[T60]  A_{15, 1}[T60]  ...  A_{15, 6}[T63]  A_{15, 7}[T63]   A_{15, 8}[T60]  A_{15, 9}[T60]  ...  A_{15, 14}[T63]  A_{15, 15}[T63]
-      -----------------------------/\-----------------------------------
-                          warp tile size 8
-
- A_{32, 0}[T0]  A_{32, 1}[T0]  ...  A_{32, 6}[T3]  A_{32, 7}[T3]   A_{32, 8}[T0]  A_{32, 9}[T0]  ...  A_{32, 14}[T3]  A_{32, 15}[T3]
- A_{33, 0}[T0]  A_{33, 1}[T0]  ...  A_{33, 6}[T3]  A_{33, 7}[T3]   A_{33, 8}[T0]  A_{33, 9}[T0]  ...  A_{33, 14}[T3]  A_{33, 15}[T3]
- A_{62, 0}[T60]  A_{62, 1}[T60]  ...  A_{62, 6}[T63]  A_{62, 7}[T63]   A_{62, 8}[T60]  A_{62, 9}[T60]  ...  A_{62, 14}[T63]  A_{62, 15}[T63]
- A_{63, 0}[T60]  A_{63, 1}[T60]  ...  A_{63, 6}[T63]  A_{63, 7}[T63]   A_{63, 8}[T60]  A_{63, 9}[T60]  ...  A_{63, 14}[T63]  A_{63, 15}[T63]
-
-And the associated TritonGPU MLIR
-#LAYOUT = #triton_gpu.blocked_layout<{
-  threadTileSize = {2, 2}
-  blockTileSize = {32, 8}
-}>
-
-// note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize`
-// probably clearer to have easier semantics (i.e., size of each tile owned by a thread or a block)
+used to promote memory coalescing in LoadInst and StoreInst.
+It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
+specify the number of elements owned by each CUDA thread, warp, and CTA respectively.
+
+For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
+
+[ 0  0  1  1  2  2  3  3 ; 32 32 33 33 34 34 35 35 ]
+[ 0  0  1  1  2  2  3  3 ; 32 32 33 33 34 34 35 35 ]
+[ 4  4  5  5  6  6  7  7 ; 36 36 37 37 38 38 39 39 ]
+[ 4  4  5  5  6  6  7  7 ; 36 36 37 37 38 38 39 39 ]
+...
+[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
+[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
+
+for
+
+#triton_gpu.blocked_layout<{
+  sizePerThread = {2, 2}
+  threadsPerWarp = {8, 4}
+  warpsPerCTA = {1, 2}
+}>
   }];
+  let builders = [
+    // Custom builder initializes sizePerWarp and sizePerCTA automatically
+    // TODO: compiles on MacOS but not linux?
+    // AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
+    //                  "ArrayRef<unsigned>":$threadsPerWarp,
+    //                  "ArrayRef<unsigned>":$warpsPerCTA,
+    //                  "ArrayRef<unsigned>":$order), [{
+    //     int rank = threadsPerWarp.size();
+    //     SmallVector<unsigned, 4> sizePerWarp(rank);
+    //     SmallVector<unsigned, 4> sizePerCTA(rank);
+    //     for (unsigned i = 0; i < rank; i++) {
+    //       sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
+    //       sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
+    //     }
+    //     return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
+    // }]>,
+
+    // Default builder takes sizePerThread, order and numWarps, and tries to
+    // pack numWarps*32 threads in the provided order for use in a type
+    // of the given shape.
+    AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
+                     "ArrayRef<unsigned>":$sizePerThread,
+                     "ArrayRef<unsigned>":$order,
+                     "unsigned":$numWarps), [{
+        int rank = sizePerThread.size();
+        int remainingWarps = numWarps;
+        int remainingLanes = 32;
+        SmallVector<unsigned, 4> threadsPerWarp(rank);
+        SmallVector<unsigned, 4> warpsPerCTA(rank);
+        for (int _dim = 0; _dim < rank; ++_dim) {
+          int dim = order[_dim];
+          int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
+          warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
+          threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
+          remainingWarps /= warpsPerCTA[dim];
+          remainingLanes /= threadsPerWarp[dim];
+        }
+        return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
+    }]>
+  ];
   let parameters = (
     ins
-    // TODO: should we rename this as laneTileSize?
-    ArrayRefParameter<
-      "unsigned",
-      /*desc*/"size of a tile that is holded by a thread"
-    >:$threadTileSize,
-    ArrayRefParameter<
-      "unsigned",
-      "size of the a tile that is holded by a warp"
-    >:$warpTileSize,
-    ArrayRefParameter<
-      "unsigned",
-      "size of a tile that is holded by a thread block"
-    >:$blockTileSize,
-    // // TODO: It seems that we don't need this (because we can re-compute this)
-    // ArrayRefParameter<"unsigned">:$reptitions,
+    ArrayRefParameter<"unsigned">:$sizePerThread,
+    ArrayRefParameter<"unsigned">:$threadsPerWarp,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     // fastest-changing axis first
     ArrayRefParameter<
       "unsigned",
       "order of axes by the rate of changing"
-    >:$order,
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-    // "AffineMap":$threadOrdering,
-    // "AffineMap":warpOrdering,
-    // "AffineMap":$blockOrdering,
+    >:$order
+    // These attributes can be inferred from the rest
+    // ArrayRefParameter<"unsigned">:$sizePerWarp,
+    // ArrayRefParameter<"unsigned">:$sizePerCTA
   );
-  // let genVerifyDecl = 1;
 }
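Two short Python sketches may make the blocked encoding concrete; both are illustrative readings of the documentation above, not code from this commit. First, the lattice itself: with order = [1, 0], lanes and warps are numbered fastest along the second axis, so the owner of element (i, j) in the 16x16 example is:

```python
def blocked_owner(i, j, size_per_thread=(2, 2), threads_per_warp=(8, 4),
                  warps_per_cta=(1, 2)):
    # lane index of the owner along each dimension
    t0 = (i // size_per_thread[0]) % threads_per_warp[0]
    t1 = (j // size_per_thread[1]) % threads_per_warp[1]
    # warp index of the owner along each dimension
    w0 = (i // (size_per_thread[0] * threads_per_warp[0])) % warps_per_cta[0]
    w1 = (j // (size_per_thread[1] * threads_per_warp[1])) % warps_per_cta[1]
    lane = t0 * threads_per_warp[1] + t1
    warp = w0 * warps_per_cta[1] + w1
    return 32 * warp + lane

# reproduces the first row of the matrix above
assert [blocked_owner(0, j) for j in range(16)] == \
    [0, 0, 1, 1, 2, 2, 3, 3, 32, 32, 33, 33, 34, 34, 35, 35]
```

Second, the default AttrBuilder above, transcribed line by line (std::clamp(x, 1, hi) becomes min(max(x, 1), hi)):

```python
def blocked_layout(shape, size_per_thread, order, num_warps, warp_size=32):
    # walk dimensions from fastest- to slowest-changing, greedily assigning
    # the 32 lanes of a warp and then the warps of the CTA
    rank = len(size_per_thread)
    threads_per_warp = [1] * rank
    warps_per_cta = [1] * rank
    remaining_warps, remaining_lanes = num_warps, warp_size
    for dim in order:  # fastest-changing axis first
        max_num_threads = shape[dim] // size_per_thread[dim]
        warps_per_cta[dim] = min(max(remaining_warps, 1), max_num_threads)
        threads_per_warp[dim] = min(max(remaining_lanes, 1), max_num_threads)
        remaining_warps //= warps_per_cta[dim]
        remaining_lanes //= threads_per_warp[dim]
    return {"sizePerThread": list(size_per_thread),
            "threadsPerWarp": threads_per_warp,
            "warpsPerCTA": warps_per_cta,
            "order": list(order)}

# 16x16 tensor, 2x2 elements per thread, row-major order, 2 warps:
print(blocked_layout((16, 16), (2, 2), (1, 0), 2))
# {'sizePerThread': [2, 2], 'threadsPerWarp': [4, 8],
#  'warpsPerCTA': [1, 2], 'order': [1, 0]}
```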
-def TritonGPUBlockedMulticastEncodingAttr
-    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
-  let mnemonic = "blocked_multicast_layout";
-
-  let description = [{
-    to be broadcasted to blocked_layout
-  }];
-
-  // This needs to be synced with BlockedEncoding
-  let parameters = (
-    ins
-    ArrayRefParameter<"unsigned">:$threadTileSize,
-    ArrayRefParameter<"unsigned">:$warpTileSize,
-    ArrayRefParameter<"unsigned">:$blockTileSize,
-    ArrayRefParameter<"unsigned">:$order,
-    // unique to broadcasted layout
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-  );
-
-  // let genVerifyDecl = 1;
-}
+//===----------------------------------------------------------------------===//
+// MMA Layout Encoding
+//===----------------------------------------------------------------------===//
+// TODO: MMAv1 and MMAv2 should be two instances of the same class
 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
-  let mnemonic = "mma_layout";
+  let mnemonic = "mma";
-  let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
-  let parameters = (
-    ins
-    // only used by Volta mma.884
-    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
-    // aka shapeOfInstr (e.g., {16,8,16})
-    ArrayRefParameter<"unsigned">:$shapePerWarp,
-    // TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
-    ArrayRefParameter<"unsigned">:$warpPerTile,
-    // TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
-    ArrayRefParameter<"unsigned">:$shapePerTile,
-    // TODO: should Distributed layout also
-    ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread,
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-    // "AffineMap":$warpOrdering,
-    // "AffineMap":$blockOrdering
-  );
-  // let genVerifyDecl = 1;
-}
-
-def TritonGPUMmaMulticastEncodingAttr
-    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
-  let mnemonic = "mma_multicast_layout";
-  let description = [{
-    To be broadcasted to mma.
-  }];
-  // This needs to be synced with MmaEncoding
-  let parameters = (
-    ins
-    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
-    ArrayRefParameter<"unsigned">:$shapePerWarp,
-    ArrayRefParameter<"unsigned">:$warpPerTile,
-    ArrayRefParameter<"unsigned">:$shapePerTile,
-    ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread,
-    // unique to broadcasted layout
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-  );
-  // let genVerifyDecl = 1;
-}
+  let description = [{
+An encoding for tensors that have been produced by tensor cores.
+It is characterized by two parameters:
+- A 'version' which specifies the generation of the tensor cores
+  whose output is being partitioned: 1 for first-gen tensor cores (Volta),
+  and 2 for second-gen tensor cores (Turing/Ampere).
+- A `blockTileSize` to indicate how data should be
+  partitioned between warps.
+
+// -------------------------------- version = 1 --------------------------- //
+For first-gen tensor cores, the implicit warpTileSize is [16, 16].
+Information about this layout can be found in the official PTX documentation
+https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+(mma.884 section, FP32 accumulator).
+For example, the matrix L corresponding to blockTileSize=[32,16] is:
+                               warp 0
+--------------------------------/\-------------------------------
+[ 0  0  2  2  0  0  2  2  4  4  6  6  4  4  6  6 ]
+[ 1  1  3  3  1  1  3  3  5  5  7  7  5  5  7  7 ]
+[ 0  0  2  2  0  0  2  2  4  4  6  6  4  4  6  6 ]
+[ 1  1  3  3  1  1  3  3  5  5  7  7  5  5  7  7 ]
+[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22 ]
+[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23 ]
+[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22 ]
+[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23 ]
+[ 8  8  10 10 8  8  10 10 12 12 14 14 12 12 14 14 ]
+[ 9  9  11 11 9  9  11 11 13 13 15 15 13 13 15 15 ]
+[ ..............................................................
+[ ..............................................................
+[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30 ]
+[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31 ]
+                          warp 1 = warp 0 + 32
+--------------------------------/\-------------------------------
+[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38 ]
+[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39 ]
+[ ..............................................................
+[ ..............................................................
+[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62 ]
+[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63 ]
+
+// -------------------------------- version = 2 --------------------------- //
+For second-gen tensor cores, the implicit warpTileSize is [16, 8].
+Information about this layout can be found in the official PTX documentation
+https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+(mma.16816 section, FP32 accumulator).
+For example, the matrix L corresponding to blockTileSize=[32,16] is:
+          warp 0                           warp 1
+-----------------/\-------------  ----------------/\-------------
+[ 0   0   1   1   2   2   3   3    32  32  33  33  34  34  35  35
+[ 4   4   5   5   6   6   7   7    36  36  37  37  38  38  39  39
+[ ..............................   ..............................
+[ 28  28  29  29  30  30  31  31   60  60  61  61  62  62  63  63
+[ 0   0   1   1   2   2   3   3    32  32  33  33  34  34  35  35
+[ 4   4   5   5   6   6   7   7    36  36  37  37  38  38  39  39
+[ ..............................   ..............................
+[ 28  28  29  29  30  30  31  31   60  60  61  61  62  62  63  63
+          warp 2                           warp 3
+----------------/\-------------   ----------------/\-------------
+[ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
+[ 68  68  69  69  70  70  71  71   100 100 101 101 102 102 103 103
+[ ..............................   ...............................
+[ 92  92  93  93  94  94  95  95   124 124 125 125 126 126 127 127
+[ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
+[ 68  68  69  69  70  70  71  71   100 100 101 101 102 102 103 103
+[ ..............................   ...............................
+[ 92  92  93  93  94  94  95  95   124 124 125 125 126 126 127 127
+  }];
+
+  let parameters = (
+    ins
+    "unsigned":$version,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA
+  );
+}

 #endif
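The version = 2 lattice above can be regenerated from the PTX accumulator mapping it references: lane t of a warp owns rows {t//4, t//4 + 8} and columns {2*(t%4), 2*(t%4) + 1} of its 16x8 warp tile. A Python sketch (an assumption based on the mma.16816 FP32 layout cited above, not code from this commit):

```python
def mma_v2_lattice(warps_per_cta=(2, 2)):
    rows, cols = 16 * warps_per_cta[0], 8 * warps_per_cta[1]
    L = [[0] * cols for _ in range(rows)]
    for w0 in range(warps_per_cta[0]):
        for w1 in range(warps_per_cta[1]):
            warp = w0 * warps_per_cta[1] + w1
            for lane in range(32):
                r, c = lane // 4, 2 * (lane % 4)
                for dr in (0, 8):       # the two accumulator row halves
                    for dc in (0, 1):   # two contiguous f32 values
                        L[16 * w0 + r + dr][8 * w1 + c + dc] = 32 * warp + lane
    return L

# first row of the blockTileSize=[32,16] example:
# [0, 0, 1, 1, 2, 2, 3, 3, 32, 32, 33, 33, 34, 34, 35, 35]
print(mma_v2_lattice((2, 2))[0])
```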

View File

@@ -7,16 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
   let summary = "pipeline";
   let description = [{
-    scf.for() {
-      %a = load %a_ptr;
-      %b = load %b_ptr;
-      %d = dot %a, %b, %c;
-    }
-    =>
-    ...
+    TODO
   }];
   let constructor = "mlir::createTritonGPUPipelinePass()";

View File

@@ -13,11 +13,11 @@ namespace mlir {
 class TritonGPUTypeConverter : public TypeConverter {
 public:
-  TritonGPUTypeConverter(MLIRContext *context, int numThreads);
+  TritonGPUTypeConverter(MLIRContext *context, int numWarps);

 private:
   MLIRContext *context;
-  int numThreads;
+  int numWarps;
 };

 class TritonGPUConversionTarget : public ConversionTarget {
@@ -26,9 +26,6 @@ class TritonGPUConversionTarget : public ConversionTarget {
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx,
                                      TritonGPUTypeConverter &typeConverter);
-
-  /// update layouts & insert ConvertLayoutOp if necessary
-  LogicalResult refineLayouts(ModuleOp mod, int numThreads);
 };

 } // namespace mlir

View File

@@ -48,17 +48,17 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
       divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
     }
   }
-  ContiguityT contiguity(rank, 1);
-  DivisibilityT divisibility(rank, divHint);
-  ConstancyT constancy(rank, 1);
+  DimVectorT contiguity(rank, 1);
+  DimVectorT divisibility(rank, divHint);
+  DimVectorT constancy(rank, 1);
   return AxisInfo(contiguity, divisibility, constancy);
 }

 // The gcd of both arguments for each dimension
 AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
-  ContiguityT retContiguity;
-  DivisibilityT retDivisibility;
-  ConstancyT retConstancy;
+  DimVectorT retContiguity;
+  DimVectorT retDivisibility;
+  DimVectorT retConstancy;
   for (size_t d = 0; d < lhs.getRank(); d++) {
     retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
     retDivisibility.push_back(
@@ -78,9 +78,9 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
     const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
     const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
   int rank = lhsInfo.getRank();
-  AxisInfo::ContiguityT newContiguity;
-  AxisInfo::DivisibilityT newDivisibility;
-  AxisInfo::ConstancyT newConstancy;
+  AxisInfo::DimVectorT newContiguity;
+  AxisInfo::DimVectorT newDivisibility;
+  AxisInfo::DimVectorT newConstancy;
   for (size_t d = 0; d < rank; d++) {
     newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
     newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
@@ -101,9 +101,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
           llvm::dyn_cast<triton::MakeRangeOp>(op)) {
     int start = make_range.start();
     int end = make_range.end();
-    AxisInfo::ContiguityT contiguity = {end - start};
-    AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
-    AxisInfo::ConstancyT constancy = {1};
+    AxisInfo::DimVectorT contiguity = {end - start};
+    AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
+    AxisInfo::DimVectorT constancy = {1};
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
   // Constant
@@ -119,9 +119,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
       auto value = splatAttr.getSplatValue<int>();
       TensorType ty = splatAttr.getType().cast<TensorType>();
       curr = AxisInfo(
-          AxisInfo::ContiguityT(ty.getRank(), 1),
-          AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
-          AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
+          AxisInfo::DimVectorT(ty.getRank(), 1),
+          AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
+          AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
     }
   }
   // Addition
@@ -156,9 +156,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     Type _retTy = *op->result_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::ContiguityT contiguity;
-    AxisInfo::DivisibilityT divisibility;
-    AxisInfo::ConstancyT constancy;
+    AxisInfo::DimVectorT contiguity;
+    AxisInfo::DimVectorT divisibility;
+    AxisInfo::DimVectorT constancy;
     for (size_t d = 0; d < retTy.getRank(); d++) {
       contiguity.push_back(1);
       divisibility.push_back(opInfo.getDivisibility(0));
@@ -167,8 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
   // Reshape
-  // TODO: Replace by `unsqueeze`
-  if (llvm::isa<triton::ReshapeOp>(op)) {
+  if (llvm::isa<triton::ViewOp>(op)) {
     Type _retTy = *op->result_type_begin();
     Type _opTy = *op->operand_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();
@@ -176,9 +175,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     ArrayRef<int64_t> retShape = retTy.getShape();
     ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::ContiguityT contiguity;
-    AxisInfo::DivisibilityT divisibility;
-    AxisInfo::ConstancyT constancy;
+    AxisInfo::DimVectorT contiguity;
+    AxisInfo::DimVectorT divisibility;
+    AxisInfo::DimVectorT constancy;
     bool is_skewed = false;
     size_t current = 0;
     for (size_t d = 0; d < retTy.getRank(); d++) {
@@ -209,9 +208,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     ArrayRef<int64_t> retShape = retTy.getShape();
     ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::ContiguityT contiguity;
-    AxisInfo::DivisibilityT divisibility;
-    AxisInfo::ConstancyT constancy;
+    AxisInfo::DimVectorT contiguity;
+    AxisInfo::DimVectorT divisibility;
+    AxisInfo::DimVectorT constancy;
     for (size_t d = 0; d < retTy.getRank(); d++) {
       contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
       divisibility.push_back(opInfo.getDivisibility(d));
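The join used by the analysis ("the gcd of both arguments for each dimension") is easy to restate in Python. A sketch, not from this commit:

```python
from math import gcd

def join(lhs, rhs):
    # each per-dimension quantity of the joined fact is the gcd of the
    # corresponding quantities of the two incoming facts
    return tuple(tuple(gcd(a, b) for a, b in zip(x, y))
                 for x, y in zip(lhs, rhs))

# facts are (contiguity, divisibility, constancy), one entry per dimension
a = ((4, 1), (8, 4), (1, 2))
b = ((2, 1), (4, 4), (1, 4))
assert join(a, b) == ((2, 1), (4, 4), (1, 2))
```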

View File

@@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
   }
 };

+struct TritonBroadcastPattern
+    : public OpConversionPattern<triton::BroadcastOp> {
+  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
+
+  // This creates a tensor with the new shape but the argument's layout
+  LogicalResult
+  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = adaptor.src().getType().cast<RankedTensorType>();
+    auto srcEncoding = srcType.getEncoding();
+    if (!srcEncoding)
+      return failure();
+    auto opType = op.getType().cast<RankedTensorType>();
+    Type retType = RankedTensorType::get(opType.getShape(),
+                                         opType.getElementType(), srcEncoding);
+    // Type retType = this->getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
+                                                     adaptor.getOperands());
+    return success();
+  }
+};
+
 struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
@@ -234,9 +256,9 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<TritonGenericPattern<triton::ReshapeOp>,
-               TritonGenericPattern<triton::SplatOp>,
-               TritonGenericPattern<triton::BroadcastOp>,
+  patterns.add< // TODO: view should have custom pattern that views the layout
+               TritonGenericPattern<triton::ViewOp>,
+               TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
                TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
                TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
                TritonStorePattern>(typeConverter, context);
@@ -317,9 +339,8 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
-    int numThreads = numWarps * 32;
     // type converter
-    TritonGPUTypeConverter typeConverter(context, numThreads);
+    TritonGPUTypeConverter typeConverter(context, numWarps);
     TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
@@ -335,8 +356,8 @@ public:
     // update layouts
     //  broadcast src => multicast, dst => broadcasted
-    if (failed(target.refineLayouts(mod, numWarps)))
-      return signalPassFailure();
+    // if (failed(target.refineLayouts(mod, numWarps)))
+    //   return signalPassFailure();
   }
 };

View File

@@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
   return success();
 };

+static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
+                               unsigned &value, StringRef desc) {
+  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+  if (!intAttr) {
+    parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
+    return failure();
+  }
+  value = intAttr.getUInt();
+  return success();
+};
+
 //===----------------------------------------------------------------------===//
 // Attribute methods
 //===----------------------------------------------------------------------===//
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"

-static Attribute parseBlocked(AsmParser &parser, Type type) {
+//===----------------------------------------------------------------------===//
+// Blocked Encoding
+//===----------------------------------------------------------------------===//
+
+Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -46,32 +61,30 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};

-  SmallVector<unsigned, 2> threadTileSize;
-  SmallVector<unsigned, 2> warpTileSize;
-  SmallVector<unsigned, 2> blockTileSize;
+  SmallVector<unsigned, 2> sizePerThread;
+  SmallVector<unsigned, 2> threadsPerWarp;
+  SmallVector<unsigned, 2> warpsPerCTA;
   SmallVector<unsigned, 2> order;
-  SmallVector<unsigned, 2> broadcastAxis;

   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "threadTileSize") {
-      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
+    if (attr.getName() == "sizePerThread") {
+      if (parseIntArrayAttr(parser, attr, sizePerThread,
+                            "number of elements per thread")
               .failed())
         return {};
-    } else if (attr.getName() == "warpTileSize") {
-      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
+    } else if (attr.getName() == "threadsPerWarp") {
+      if (parseIntArrayAttr(parser, attr, threadsPerWarp,
+                            "number of threads per warp")
              .failed())
        return {};
-    } else if (attr.getName() == "blockTileSize") {
-      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
+    } else if (attr.getName() == "warpsPerCTA") {
+      if (parseIntArrayAttr(parser, attr, warpsPerCTA,
+                            "number of warps per CTA")
              .failed())
        return {};
     } else if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
         return {};
-    } else if (attr.getName() == "broadcastAxis") {
-      if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
-              .failed())
-        return {};
     } else {
       parser.emitError(parser.getNameLoc(), "unexpected key: ")
           << attr.getName().strref();
@@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
   }

   return parser.getChecked<TritonGPUBlockedEncodingAttr>(
-      parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
-      broadcastAxis);
-}
-
-template <class T>
-static void printBlocked(AsmPrinter &printer, const T *attr) {
-  printer << "<{"
-          << "threadTileSize = [" << attr->getThreadTileSize() << "]"
-          << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
-          << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
-          << ", order = [" << attr->getOrder() << "]"
-          << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
-          << "}>";
-}
-
-Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
-  return parseBlocked(parser, type);
+      parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
 }

 void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
-  printBlocked(printer, this);
+  printer << "<{"
+          << "sizePerThread = [" << getSizePerThread() << "]"
+          << ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
+          << ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
+          << ", order = [" << getOrder() << "]"
+          << "}>";
 }

-Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
-                                                       Type type) {
-  return parseBlocked(parser, type);
-}
-
-void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
-  printBlocked(printer, this);
-}
-
-static Attribute parseMma(AsmParser &parser, Type type) {
+//===----------------------------------------------------------------------===//
+// MMA encoding
+//===----------------------------------------------------------------------===//
+
+Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   DictionaryAttr dict;
@@ -121,76 +118,34 @@ static Attribute parseMma(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};

-  SmallVector<unsigned, 2> fragmentPerWarp;
-  SmallVector<unsigned, 2> shapePerWarp;
-  SmallVector<unsigned, 2> warpPerTile;
-  SmallVector<unsigned, 2> shapePerTile;
-  SmallVector<unsigned, 2> repetitions;
-  SmallVector<unsigned, 2> contigPerThread;
-  SmallVector<unsigned, 2> broadcastAxis;
+  unsigned version = 0;
+  SmallVector<unsigned, 2> warpsPerCTA;

   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "fragmentPerWarp") {
-      if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
-              .failed())
+    if (attr.getName() == "version") {
+      if (parseUInt(parser, attr, version, "version").failed())
         return {};
-    } else if (attr.getName() == "shapePerWarp") {
-      if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
-              .failed())
-        return {};
-    } else if (attr.getName() == "warpPerTile") {
-      if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
-        return {};
-    } else if (attr.getName() == "shapePerTile") {
-      if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
-              .failed())
-        return {};
-    } else if (attr.getName() == "repetitions") {
-      if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
-        return {};
-    } else if (attr.getName() == "contigPerThread") {
-      if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
-              .failed())
-        return {};
-    } else {
-      parser.emitError(parser.getNameLoc(), "unexpected key: ")
-          << attr.getName().strref();
+    }
+    if (attr.getName() == "warpsPerCTA") {
+      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
         return {};
     }
   }

-  return parser.getChecked<TritonGPUMmaEncodingAttr>(
-      parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
-      shapePerTile, repetitions, contigPerThread, broadcastAxis);
-}
-
-template <class T> static void printMma(AsmPrinter &printer, T *attr) {
-  printer << "<{"
-          << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
-          << ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
-          << ", warpPerTile = [" << attr->getWarpPerTile() << "]"
-          << ", shapePerTile = [" << attr->getShapePerTile() << "]"
-          << ", repetitions = [" << attr->getRepetitions() << "]"
-          << ", contigPerThread = [" << attr->getContigPerThread() << "]"
-          << "}>";
-}
-
-Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
-  return parseMma(parser, type);
+  return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
+                                                     version, warpsPerCTA);
 }

 void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
-  printMma(printer, this);
+  printer << "<{"
+          << "version = " << getVersion() << ", "
+          << "warpsPerCTA = [" << getWarpsPerCTA() << "]"
+          << "}>";
 }

-Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
-                                                   Type type) {
-  return parseMma(parser, type);
-}
-
-void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
-  printMma(printer, this);
-}
+//===----------------------------------------------------------------------===//
+// Shared encoding
+//===----------------------------------------------------------------------===//

 Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   unsigned maxPhase = 0;
   SmallVector<unsigned, 2> order;

-  auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
-                             StringRef desc) -> LogicalResult {
-    auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
-    if (!intAttr) {
-      parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
-      return failure();
-    }
-    value = intAttr.getUInt();
-    return success();
-  };
-
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "vec") {
-      if (parseUInt(attr, vec, "vec").failed())
+      if (parseUInt(parser, attr, vec, "vec").failed())
         return {};
     } else if (attr.getName() == "perPhase") {
-      if (parseUInt(attr, perPhase, "perPhase").failed())
+      if (parseUInt(parser, attr, perPhase, "perPhase").failed())
        return {};
     } else if (attr.getName() == "maxPhase") {
-      if (parseUInt(attr, maxPhase, "maxPhase").failed())
+      if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
         return {};
     } else if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
@@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
           << "}>";
 }

+//===----------------------------------------------------------------------===//
+// ASM Interface (i.e.: alias)
+//===----------------------------------------------------------------------===//
+
 class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
 public:
   using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -257,72 +205,18 @@ public:
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
     if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
       os << "mma";
-      TritonGPUOpAsmInterface::printMma(mmaAttr, os);
-      return AliasResult::FinalAlias;
-    } else if (auto mmaMulticastAttr =
-                   attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
-      os << "mma_multicast";
-      TritonGPUOpAsmInterface::printMma(mmaAttr, os);
       return AliasResult::FinalAlias;
     } else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
       os << "shared";
-      TritonGPUOpAsmInterface::printShared(sharedAttr, os);
       return AliasResult::FinalAlias;
     } else if (auto blockedAttr =
                    attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
       os << "blocked";
-      TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
-      return AliasResult::FinalAlias;
-    } else if (auto blockedMulticastAttr =
-                   attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
-      os << "blocked_multicast";
-      TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
       return AliasResult::FinalAlias;
     }
     OpAsmDialectInterface::getAlias(attr, os);
     return AliasResult::FinalAlias;
   }
-
-private:
-  static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
-    TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
-  }
-
-  static void printShared(const TritonGPUSharedEncodingAttr &attr,
-                          raw_ostream &os) {
-    os << "_" << attr.getVec();
-    os << "_" << attr.getPerPhase();
-    os << "_" << attr.getMaxPhase();
-    TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
-  }
-
-  template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
-    TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
-  }
-
-  static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
-                         const std::string &delimiter = "x") {
-    os << "_";
-    if (array.empty()) {
-      os << "none";
-      return;
-    }
-    for (unsigned i = 0; i < array.size(); i++) {
-      os << array[i];
-      if (i != array.size() - 1) {
-        os << delimiter;
-      }
-    }
-  }
 };

 void TritonGPUDialect::initialize() {

View File

@@ -3,6 +3,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include <numeric>

 using namespace mlir;
 using namespace mlir::triton::gpu;
@@ -11,54 +12,26 @@ using namespace mlir::triton::gpu;
 // TypeConverter
 //
 TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-                                               int numThreads)
-    : context(context), numThreads(numThreads) {
+                                               int numWarps)
+    : context(context), numWarps(numWarps) {
   // TODO: how does MLIR pick the right conversion?
   addConversion([](Type type) { return type; });
   addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
-    MLIRContext *context = this->context;
-    int numThreads = this->numThreads;
-
-    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
-    Type elementType = tensorType.getElementType();
-    int64_t rank = tensorType.getRank();
-    int64_t numElements = tensorType.getNumElements();
-
-    // TODO: are there any better ways to raise this error?
-    if (!(numElements >= numThreads)) {
-      SmallVector<char> buffer;
-      llvm::raw_svector_ostream os(buffer);
-      os << tensorType << " has " << numElements << " numElements "
-         << " smaller than numThreads (" << numThreads << ")\n"
-         << "consider using smaller num-warps\n";
-      llvm::report_fatal_error(os.str());
-    }
-    assert(numElements % numThreads == 0);
-
-    // or assert no encoding?
-
-    // Now we assume:
-    //   contiguous = 1, order = 0, 1, 2, ...,
-    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
-    llvm::SmallVector<unsigned> warpTileSize(rank, 1);
-    llvm::SmallVector<unsigned> blockTileSize(rank);
+    // types with encoding are already in the right format
+    // TODO: check for layout encodings specifically
+    if (tensorType.getEncoding())
+      return tensorType;
+    // pessimistic values for attributes:
+    //   - 1 element per thread
+    //   - order = arange(rank)
+    ArrayRef<int64_t> shape = tensorType.getShape();
+    int rank = shape.size();
     llvm::SmallVector<unsigned> order(rank);
-    llvm::SmallVector<unsigned> broadcastAxis;
-    int remainingThreads = numThreads;
-    int remainingLanes = /*warp size*/ 32;
-    for (int64_t dim = 0; dim < rank; ++dim) {
-      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
-      warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
-      order[dim] = dim;
-      remainingThreads /= blockTileSize[dim];
-      remainingLanes /= warpTileSize[dim];
-      // TODO: will we need repetition?
-    }
+    std::iota(order.begin(), order.end(), 0);
+    llvm::SmallVector<unsigned> sizePerThread(rank, 1);
     Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order,
-        broadcastAxis);
-    return RankedTensorType::get(shape, elementType, encoding);
+        this->context, shape, sizePerThread, order, this->numWarps);
+    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
   });

 //
@@ -86,8 +59,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
-    llvm_unreachable("Not implemented");
-    return llvm::None;
+    auto cast =
+        builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
+    return Optional<Value>(cast.getResult());
+    // return Optional<Value>(cast.getResult(0));
+    // llvm_unreachable("Not implemented");
+    // return llvm::None;
   });
 }
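The new conversion rule is small enough to restate. Here is a hedged Python sketch (names are invented for illustration; the real code returns a RankedTensorType carrying a TritonGPUBlockedEncodingAttr):

```python
def convert_tensor_type(shape, encoding, num_warps):
    if encoding is not None:
        return encoding           # already a TritonGPU tensor type
    rank = len(shape)
    size_per_thread = [1] * rank  # pessimistic: 1 element per thread
    order = list(range(rank))     # order = arange(rank)
    # these values are handed to the default blocked-encoding builder
    # sketched earlier, which packs num_warps * 32 threads along `order`
    return ("blocked", shape, size_per_thread, order, num_warps)
```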
@@ -122,87 +99,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
         aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
         bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
       return true;
-    // // TODO: we should delete this
-    // if (this->typeConverter.isLegal(dotOp))
-    //   return true;
     return false;
   });
 }
-
-// %dst = tt.broadcast %src
-//   =>
-// %newSrc = convert_layout %src
-// %bcst = tt.broadcast %newSrc
-// %dst = convert_layout %bcst
-LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
-                                                       int numThreads) {
-  // collect broadcasts
-  SmallVector<triton::BroadcastOp> broadcasts;
-  mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
-
-  BlockAndValueMapping mapping;
-  for (auto broadcast : broadcasts) {
-    OpBuilder builder(broadcast);
-    Value src = mapping.lookupOrDefault(broadcast.src());
-    Type originSrcType = src.getType();
-    Type originDstType = broadcast.getType();
-    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
-    unsigned dstRank = originDstTensorType.getRank();
-
-    // compute newSrcType & broadcastAxis
-    Type newSrcType;
-    SmallVector<unsigned> broadcastAxis;
-    bool isSrcScalar = false;
-    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
-      assert(tensorType.getRank() == dstRank &&
-             "src & dst should have same rank (verifier should catch this)");
-      for (unsigned ax = 0; ax < dstRank; ++ax)
-        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
-          broadcastAxis.push_back(ax);
-
-      Attribute originSrcEnc = tensorType.getEncoding();
-      if (auto blockedEnc =
-              originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
-        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
-            blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
-            blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
-            blockedEnc.getOrder(), broadcastAxis);
-        newSrcType = RankedTensorType::get(
-            tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
-      } else
-        llvm_unreachable("src of broadcast should have blocked encoding");
-    } else {
-      for (unsigned ax = 0; ax < dstRank; ++ax)
-        broadcastAxis.push_back(ax);
-      newSrcType = originSrcType;
-      isSrcScalar = true;
-    }
-
-    // create new src
-    if (!isSrcScalar) // we don't need to convert layout for scalar values
-      src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
-                                                         newSrcType, src);
-    // create new broadcast
-    // compute new type (encoding)
-    auto originDstEnc = originDstTensorType.getEncoding()
-                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
-    auto newEnc = TritonGPUBlockedEncodingAttr::get(
-        originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
-        originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
-        originDstEnc.getOrder(), broadcastAxis);
-    auto newType =
-        RankedTensorType::get(originDstTensorType.getShape(),
-                              originDstTensorType.getElementType(), newEnc);
-    Value newBroadcast =
-        builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
-    // we don't want to change the encoding of the result
-    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
-        broadcast.getLoc(), originDstType, newBroadcast);
-    broadcast.replaceAllUsesWith(newDst);
-    mapping.map(broadcast, newDst);
-  }
-  return success();
-}

View File

@@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     off_m = tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, BLOCK_N)
-    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
-    Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
+    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
     tl.store(Zs, tl.load(Xs))

-ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
+ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
 print(ret)

View File

@@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) {
             self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
           })
       // Block instruction
-      .def("create_reshape",
+      .def("create_view",
            [](mlir::OpBuilder &self, mlir::Value &arg,
               std::vector<int64_t> &shape) -> mlir::Value {
              auto loc = self.getUnknownLoc();
              auto argType = arg.getType()
                                 .dyn_cast<mlir::RankedTensorType>()
                                 .getElementType();
-             return self.create<mlir::triton::ReshapeOp>(
+             return self.create<mlir::triton::ViewOp>(
                  loc, mlir::RankedTensorType::get(shape, argType), arg);
            })
       .def("create_cat",

View File

@@ -565,7 +565,7 @@ class tensor:
             elif sl == slice(None, None, None):
                 dst_shape.append(src_shape[curr].value)
                 curr += 1
-        ret = semantic.reshape(self, dst_shape, _builder)
+        ret = semantic.view(self, dst_shape, _builder)
         return ret

     @builtin

View File

@@ -451,7 +451,7 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
 # ===----------------------------------------------------------------------===//

-def reshape(input: tl.tensor,
-            dst_shape: List[int],
-            builder: ir.builder) -> tl.tensor:
+def view(input: tl.tensor,
+         dst_shape: List[int],
+         builder: ir.builder) -> tl.tensor:
     numel = 1
@@ -460,7 +460,7 @@ def view(input: tl.tensor,
     if input.type.numel != numel:
         raise ValueError("cannot reshape block of different shape")
     ret_ty = tl.block_type(input.type.scalar, dst_shape)
-    return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
+    return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)

 def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:

View File

@@ -10,7 +10,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
   %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
@@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
   %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
@@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
   %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
   %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]

View File

@@ -1,26 +0,0 @@
-// RUN: triton-opt %s -split-input-file -verify-diagnostics
-
-#reg = #triton_gpu.blocked_layout<{
-  threadTileSize = [1, 1],
-  warpTileSize = [32, 1],
-  blockTileSize = [64, 1],
-  order = [0]
-}>
-#reg2 = #triton_gpu.blocked_layout<{
-  threadTileSize = [2, 1],
-  warpTileSize = [64, 1],
-  blockTileSize = [128, 1],
-  order = [0]
-}>
-
-func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) {
-  %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg>
-  return
-}
-
-func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}}
-  // expected-error @+1 {{use of value '%arg0' expects different type than prior uses}}
-  %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2>
-  return
-}

View File

@@ -1,45 +1,12 @@
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s // RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s
// 4 warps // 4 warps
// matmul: 128x32 @ 32x128 -> 128x128 // matmul: 128x32 @ 32x128 -> 128x128
#AL = #triton_gpu.blocked_layout<{ #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
threadTileSize = [1, 4], #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
warpTileSize = [4, 32], #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
blockTileSize = [16, 32], #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
order = [1, 0] #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
}>
#BL = #triton_gpu.blocked_layout<{
threadTileSize = [1, 4],
warpTileSize = [1, 128],
blockTileSize = [4, 128],
order = [1, 0]
}>
#A = #triton_gpu.shared_layout<{
vec = 2,
perPhase = 2,
maxPhase = 4,
order = [1, 0]
}>
#B = #triton_gpu.shared_layout<{
vec = 2,
perPhase = 2,
maxPhase = 4,
order = [1, 0]
}>
// TODO: check this
#C = #triton_gpu.mma_layout<{
fragmentPerWarp = [1, 1],
shapePerWarp = [16, 8],
warpPerTile = [2, 2],
shapePerTile = [32, 16],
repetitions = [4, 4],
contigPerThread = [1, 8]
}>
// CHECK: func @matmul_loop // CHECK: func @matmul_loop
// CHECK: %[[A0:.*]] = triton_gpu.copy_async // CHECK: %[[A0:.*]] = triton_gpu.copy_async