|
|
|
|
|
|
|
|
#define TRITONGPU_ATTRDEFS
|
|
|
|
|
|
|
|
|
|
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
|
|
|
|
// include "mlir/IR/TensorEncoding.td"
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// TritonGPU Attribute Definitions
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
// Base class for all TritonGPU attributes (layout encodings).
// Resolved an unmerged-diff artifact: the class carried both the old
// body-less terminator (`;`) and the new braced body — invalid TableGen.
class TritonGPU_Attr<string name, list<Trait> traits = [],
                     string baseCppClass = "::mlir::Attribute">
    : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {

  let description = [{
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.

For example, let us consider the layout function:
\mathcal{L}(0, 0) = {0, 4}
\mathcal{L}(0, 1) = {1, 5}
\mathcal{L}(1, 0) = {2, 6}
\mathcal{L}(1, 1) = {3, 7}

Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
- T[0,0] is owned by both cuda thread 0 and 4
- T[0,1] is owned by both cuda thread 1 and 5
- T[1,0] is owned by both cuda thread 2 and 6
- T[1,1] is owned by both cuda thread 3 and 7

Right now, Triton implements two classes of layouts: shared, and distributed.
}];
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Shared Layout Encoding
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
// Shared-memory layout: every element is visible to all threads in the CTA.
// Resolved an unmerged-diff artifact: duplicate `let mnemonic` lines
// ("shared_layout" old, "shared" new) and two interleaved descriptions.
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
  let mnemonic = "shared";

  let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
different cuda threads in the programs, via shared memory. In other words,
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.

In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:

A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...   [phase 0] \ per_phase = 2
A_{1, 0}  A_{1, 1}  A_{1, 2}  A_{1, 3} ...   [phase 0] /
                                    groups of vec=2 elements
                                    are stored contiguously
                      _ _ _ _ /\_ _ _ _
A_{2, 2}  A_{2, 3}  A_{2, 0}  A_{2, 1} ...   [phase 1] \ per_phase = 2
A_{3, 2}  A_{3, 3}  A_{3, 0}  A_{3, 1} ...   [phase 1] /

And the associated TritonGPU MLIR

```mlir
#SMEM = #triton_gpu.shared<{
  vec = 2,
  perPhase = 2,
  maxPhase = 4,
  order = [1, 0]
}>
```
}];

  // NOTE(review): the parameter list in the merged diff was truncated by a
  // hunk header; reconstructed from the names used in the MLIR example above.
  let parameters = (
    ins
    // number of consecutive elements grouped and swizzled together
    "unsigned":$vec,
    // number of contiguous rows that share the same swizzling phase
    "unsigned":$perPhase,
    // total number of distinct swizzling phases
    "unsigned":$maxPhase,
    // fastest-changing axis first
    ArrayRefParameter<"unsigned">:$order
  );
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Distributed Layout Encoding
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
// Base for layouts whose mapping is given by tiling a layout tensor L
// over the encoded tensor. Cleaned of extraction residue only; content
// was internally consistent.
class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
  let mnemonic = "distributed";

  let description = [{
Distributed encodings have a layout function that is entirely characterized
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.

The layout function \mathcal{L} of this layout is then defined, for an
index `i` \in R^D, as follows:

\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]

For example, for a tensor/layout pair
A = [x  x  x  x  x  x  x  x]
    [x  x  x  x  x  x  x  x]
L = [0  1  2  3 ]
    [4  5  6  7 ]
    [8  9  10 11]
    [12 13 14 15]

Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
         {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Blocked Layout Encoding
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
// Blocked layout: each thread owns contiguous chunks; promotes coalescing.
// Resolved an unmerged-diff artifact: duplicate mnemonics
// ("blocked_layout"/"blocked"), two descriptions, and two parameter lists.
// Kept the new-side parameters (sizePerThread/threadsPerWarp/warpsPerCTA),
// which is the set the active builder's `$_get` call actually passes.
def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
  let mnemonic = "blocked";

  let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.

For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.

[ 0  0  1  1  2  2  3  3 ; 32 32 33 33 34 34 35 35 ]
[ 0  0  1  1  2  2  3  3 ; 32 32 33 33 34 34 35 35 ]
[ 4  4  5  5  6  6  7  7 ; 36 36 37 37 38 38 39 39 ]
[ 4  4  5  5  6  6  7  7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]

for

#triton_gpu.blocked<{
  sizePerThread = {2, 2}
  threadsPerWarp = {8, 4}
  warpsPerCTA = {1, 2}
}>
}];

  let builders = [
    // Default builder takes sizePerThread, order and numWarps, and tries to
    // pack numWarps*32 threads in the provided order for use in a type
    // of the given shape.
    AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                     "ArrayRef<unsigned>":$sizePerThread,
                     "ArrayRef<unsigned>":$order,
                     "unsigned":$numWarps), [{
        int rank = sizePerThread.size();
        int remainingWarps = numWarps;
        int remainingLanes = 32;
        SmallVector<unsigned, 4> threadsPerWarp(rank);
        SmallVector<unsigned, 4> warpsPerCTA(rank);
        // Walk the dimensions fastest-changing first (per `order`) and give
        // each as many lanes/warps as the shape can absorb; leftover lanes
        // and warps carry over to slower-changing dimensions.
        for (int _dim = 0; _dim < rank; ++_dim) {
          int dim = order[_dim];
          int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
          warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
          threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
          remainingWarps /= warpsPerCTA[dim];
          remainingLanes /= threadsPerWarp[dim];
        }
        return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
    }]>
  ];

  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$sizePerThread,
    ArrayRefParameter<"unsigned">:$threadsPerWarp,
    ArrayRefParameter<"unsigned">:$warpsPerCTA,
    // fastest-changing axis first
    ArrayRefParameter<
      "unsigned",
      "order of axes by the rate of changing"
    >:$order
  );

  // let genVerifyDecl = 1;
}
|
|
|
|
|
|
|
|
|
|
// Multicast variant of the blocked layout. Cleaned of extraction residue
// only; content was internally consistent.
def TritonGPUBlockedMulticastEncodingAttr
    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
  let mnemonic = "blocked_multicast_layout";

  let description = [{
to be broadcasted to blocked_layout
}];

  // This needs to be synced with BlockedEncoding
  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$threadTileSize,
    ArrayRefParameter<"unsigned">:$warpTileSize,
    ArrayRefParameter<"unsigned">:$blockTileSize,
    ArrayRefParameter<"unsigned">:$order,
    // unique to broadcasted layout
    ArrayRefParameter<"unsigned">:$broadcastAxis
  );

  // let genVerifyDecl = 1;
}
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// MMA Layout Encoding
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// TODO: MMAv1 and MMAv2 should be two instances of the same class
|
|
|
|
|
|
|
|
|
|
// Tensor-core (mma) output layout. Cleaned of extraction residue only;
// content was internally consistent.
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
  let mnemonic = "mma_layout";

  let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];

  let parameters = (
    ins
    // only used by Volta mma.884
    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
    // aka shapeOfInstr (e.g., {16,8,16})
    ArrayRefParameter<"unsigned">:$shapePerWarp,
    // TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
    ArrayRefParameter<"unsigned">:$warpPerTile,
    // TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
    ArrayRefParameter<"unsigned">:$shapePerTile,
    // TODO: should Distributed layout also
    ArrayRefParameter<"unsigned">:$repetitions,
    ArrayRefParameter<"unsigned">:$contigPerThread,
    ArrayRefParameter<"unsigned">:$broadcastAxis
    // "AffineMap":$warpOrdering,
    // "AffineMap":$blockOrdering
  );

  // let genVerifyDecl = 1;
}
|
|
|
|
|
|
|
|
|
|
// Multicast variant of the mma layout.
// NOTE(review): the merged diff interleaved this def with a new-side "mma"
// encoding (mnemonic "mma", `version` + `warpsPerCTA` parameters, long
// tensor-core description). That content belongs to a replacement
// MmaEncoding def, not to this multicast def, so this reconstruction keeps
// the side consistent with the def's name — confirm against upstream.
def TritonGPUMmaMulticastEncodingAttr
    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
  let mnemonic = "mma_multicast_layout";

  let description = [{
To be broadcasted to mma.
}];

  // This needs to be synced with MmaEncoding
  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
    ArrayRefParameter<"unsigned">:$shapePerWarp,
    ArrayRefParameter<"unsigned">:$warpPerTile,
    ArrayRefParameter<"unsigned">:$shapePerTile,
    ArrayRefParameter<"unsigned">:$repetitions,
    ArrayRefParameter<"unsigned">:$contigPerThread,
    // unique to broadcasted layout
    ArrayRefParameter<"unsigned">:$broadcastAxis
  );

  // let genVerifyDecl = 1;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|