[TritonGPU] Improved documentation and semantics of layout encodings (#30)

.github/workflows/integration-tests.yml
@@ -10,7 +10,7 @@ jobs:

   Integration-Tests:

-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest

     steps:

@@ -18,16 +18,14 @@ namespace mlir {

 /// Axis information is represented by a std::map<int, int>
 class AxisInfo {
 public:
-  typedef std::vector<int> ContiguityT;
-  typedef std::vector<int> DivisibilityT;
-  typedef std::vector<int> ConstancyT;
+  typedef SmallVector<int, 4> DimVectorT;

 public:
   // Default constructor
   AxisInfo() : AxisInfo({}, {}, {}) {}
   // Construct contiguity info with known contiguity
-  AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
-           ConstancyT knownConstancy)
+  AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
+           DimVectorT knownConstancy)
       : contiguity(knownContiguity), divisibility(knownDivisibility),
         constancy(knownConstancy), rank(contiguity.size()) {
     assert(knownDivisibility.size() == rank);
@@ -36,13 +34,13 @@ public:

   // Accessors
   int getContiguity(size_t d) const { return contiguity[d]; }
-  const ContiguityT &getContiguity() const { return contiguity; }
+  const DimVectorT &getContiguity() const { return contiguity; }

   int getDivisibility(size_t d) const { return divisibility[d]; }
-  const DivisibilityT &getDivisibility() const { return divisibility; }
+  const DimVectorT &getDivisibility() const { return divisibility; }

   int getConstancy(size_t d) const { return constancy[d]; }
-  const ConstancyT &getConstancy() const { return constancy; }
+  const DimVectorT &getConstancy() const { return constancy; }

   int getRank() const { return rank; }
@@ -78,7 +76,7 @@ private:
   /// [18, 22, 26, 30]
   /// [19, 23, 27, 31]
   /// would have contiguity [2, 1].
-  ContiguityT contiguity;
+  DimVectorT contiguity;

   /// The _divisibility_ information maps the `d`-th
   /// dimension to the largest power-of-two that

@@ -93,7 +91,7 @@ private:
   /// [14, 18, 22, 26]
   /// [15, 19, 23, 27]
   /// would have divisibility [4, 1]
-  DivisibilityT divisibility;
+  DimVectorT divisibility;

   /// The _constancy_ information maps the `d`-th
   /// dimension to the length of the shortest

@@ -104,7 +102,7 @@ private:
   /// [8, 8, 8, 8, 12, 12, 12, 12]
   /// [16, 16, 16, 16, 20, 20, 20, 20]
   /// would have constancy [1, 4]
-  ConstancyT constancy;
+  DimVectorT constancy;

   // number of dimensions of the lattice
   int rank;
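For intuition, the three quantities can be computed directly on a concrete 1-D slice. The following Python sketch is an illustration we added (it is not part of the commit, and the helper names are ours), following the definitions in the doc comments above:

```python
def axis_info_1d(values):
    """Compute (contiguity, divisibility, constancy) of a 1-D integer
    slice, per the definitions in AxisInfo's doc comments (sketch)."""
    n = len(values)

    def shortest_run(same_group):
        # Length of the shortest maximal run of adjacent elements
        # related by `same_group`.
        best, run = n, 1
        for a, b in zip(values, values[1:]):
            if same_group(a, b):
                run += 1
            else:
                best, run = min(best, run), 1
        return min(best, run)

    contiguity = shortest_run(lambda a, b: b == a + 1)  # consecutive values
    constancy = shortest_run(lambda a, b: b == a)       # repeated values
    # Largest power of two dividing the first element of each contiguous group.
    firsts = values[::contiguity]
    divisibility = min((f & -f) if f else 1 << 31 for f in firsts)
    return contiguity, divisibility, constancy

# Row [18, 22, 26, 30] from the contiguity example: no consecutive or repeated
# elements, so contiguity and constancy are 1; every element is even, so
# divisibility is 2.
print(axis_info_1d([18, 22, 26, 30]))          # (1, 2, 1)
print(axis_info_1d([8, 8, 8, 8, 12, 12, 12, 12]))  # constancy 4
```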
@@ -133,8 +133,8 @@ def TT_GEPOp : TT_Op<"getelementptr",
 //
 // Shape Manipulation Ops
 //
-def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]> {
-  let summary = "reshape";
+def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
+  let summary = "view";

   let arguments = (ins TT_Tensor:$src);
@@ -2,43 +2,60 @@
 #define TRITONGPU_ATTRDEFS

 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 // include "mlir/IR/TensorEncoding.td"

 //===----------------------------------------------------------------------===//
 // TritonGPU Attribute Definitions
 //===----------------------------------------------------------------------===//

 class TritonGPU_Attr<string name, list<Trait> traits = [],
                      string baseCppClass = "::mlir::Attribute">
-  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
+  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
+
+  let description = [{
+    TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
+    how the data is partitioned across CUDA threads. Formally speaking, we define a layout as a function
+    $\mathcal{L}$ that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to the set of integers $T$
+    corresponding to the indices of the CUDA threads allowed to access the data at index $i$.
+
+    For example, let us consider the layout function:
+      \mathcal{L}(0, 0) = {0, 4}
+      \mathcal{L}(0, 1) = {1, 5}
+      \mathcal{L}(1, 0) = {2, 6}
+      \mathcal{L}(1, 1) = {3, 7}
+
+    Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
+    - T[0,0] is owned by both CUDA threads 0 and 4
+    - T[0,1] is owned by both CUDA threads 1 and 5
+    - T[1,0] is owned by both CUDA threads 2 and 6
+    - T[1,1] is owned by both CUDA threads 3 and 7
+
+    Right now, Triton implements two classes of layouts: shared and distributed.
+  }];
+}
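To make the definition concrete, the 2x2 example above can be written out as a plain mapping (an added illustration, not part of the commit):

```python
# The example layout \mathcal{L} above, as an explicit index -> thread-set map.
layout = {
    (0, 0): {0, 4},
    (0, 1): {1, 5},
    (1, 0): {2, 6},
    (1, 1): {3, 7},
}

# Each element of T is owned by exactly two CUDA threads, and the thread
# sets are disjoint across indices: the data is replicated twice.
assert all(len(threads) == 2 for threads in layout.values())
```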

//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//

 def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
-  let mnemonic = "shared_layout";
+  let mnemonic = "shared";

   let description = [{
     An encoding for tensors whose elements may be simultaneously accessed by
-    different warps in the programs, via shared memory.
-
-    In order to avoid shared memory bank conflicts, elements may be stored in a
-    swizzled layout.
-    For example, a swizzled row-major layout would store its data as follows:
+    different CUDA threads in the program, via shared memory. In other words,
+    for all indices i \in \mathbb{Z}^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
+
+    In order to avoid shared memory bank conflicts, elements may be swizzled
+    in memory. For example, a swizzled row-major layout could store its data
+    as follows:

     A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...   [phase 0]  \ per_phase = 2
     A_{1, 0}  A_{1, 1}  A_{1, 2}  A_{1, 3} ...   [phase 0]  /

             groups of vec = 2 elements
              are stored contiguously
              _ _ _ _ /\ _ _ _ _
     A_{2, 2}  A_{2, 3}  A_{2, 0}  A_{2, 1} ...   [phase 1]  \ per_phase = 2
     A_{3, 2}  A_{3, 3}  A_{3, 0}  A_{3, 1} ...   [phase 1]  /

     And the associated TritonGPU MLIR:

     ```mlir
     #SMEM = #triton_gpu.shared<{
         vec = 2,
         perPhase = 2,
         maxPhase = 4,
         order = [1, 0]
     }>
     ```
   }];

   let parameters = (
@@ -49,143 +66,222 @@ And the associated TritonGPU MLIR
   );
 }
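To illustrate the figure above, here is a rough Python sketch of one possible swizzle function. This is an added illustration using the common XOR-on-phase scheme; the exact swizzle function is defined by the TritonGPU lowering, not by this commit's text:

```python
def swizzled_position(row, col, vec=2, per_phase=2, max_phase=4):
    """Return the column at which a swizzled row-major layout stores the
    element at logical (row, col). Sketch only: assumes an XOR swizzle."""
    phase = (row // per_phase) % max_phase    # `per_phase` rows share a phase
    group = col // vec                        # index of the vec-sized group
    return (group ^ phase) * vec + col % vec  # permute whole groups per phase

# Rows 0-1 (phase 0) are stored unswizzled; rows 2-3 (phase 1) have their
# vec = 2 groups exchanged, matching the A_{2, 2} A_{2, 3} A_{2, 0} ... figure.
for row in range(4):
    print([swizzled_position(row, col) for col in range(4)])
```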

//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//

class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
  let mnemonic = "distributed";

  let description = [{
    Distributed encodings have a layout function that is entirely characterized
    by a d-dimensional tensor L. Note that L doesn't need to have the same shape
    (or even the same rank) as the tensor it is encoding.

    The layout function \mathcal{L} of this layout is then defined, for an
    index i \in \mathbb{Z}^d, as follows:

    \mathcal{L}(A)[i] = { L[(i_d + k_d*A.shape[d]) % L.shape[d]]  \forall k_d
                          such that i_d + k_d*A.shape[d] < max(A.shape[d], L.shape[d]) }

    For example, for a tensor/layout pair
    A = [x  x  x  x  x  x  x  x]
        [x  x  x  x  x  x  x  x]
    L = [0  1  2  3 ]
        [4  5  6  7 ]
        [8  9  10 11]
        [12 13 14 15]

    Then the data of A would be distributed as follows between the 16 CUDA threads:
    L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
             {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
  }];
}
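A direct transcription of the layout function into Python reproduces the example above (an added illustration, not part of the commit; the function name is ours):

```python
from itertools import product

import numpy as np

def distributed_threads(a_shape, L, index):
    """Threads owning A[index] when A (of shape a_shape) is encoded by the
    distributed layout tensor L; transcribes the formula above (sketch)."""
    coords = []
    for d, i in enumerate(index):
        bound = max(a_shape[d], L.shape[d])
        # all k_d with i_d + k_d * A.shape[d] < bound
        num_k = (bound - 1 - i) // a_shape[d] + 1
        coords.append([(i + k * a_shape[d]) % L.shape[d] for k in range(num_k)])
    return {int(L[j]) for j in product(*coords)}

L = np.arange(16).reshape(4, 4)
print(distributed_threads((2, 8), L, (0, 0)))  # {0, 8}
print(distributed_threads((2, 8), L, (0, 4)))  # {0, 8} -- wraps around L's columns
print(distributed_threads((2, 8), L, (1, 1)))  # {5, 13}
```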

//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//

 def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
-  let mnemonic = "blocked_layout";
+  let mnemonic = "blocked";

   let description = [{
     An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
-    consumed (and returned) by LoadInst.
-    For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 warps (i.e. 64 threads) as follows:
-
-                                        thread tile size 2
-                          - - - - - - - - /\ - - - - - - - -
-    warp | thread ||  A_{0, 0}[T0]  A_{0, 1}[T0]  ... A_{0, 6}[T3]  A_{0, 7}[T3]  A_{0, 8}[T0]  A_{0, 9}[T0]  ... A_{0, 14}[T3]  A_{0, 15}[T3]
-    tile | tile   ||  A_{1, 0}[T0]  A_{1, 1}[T0]  ... A_{1, 6}[T3]  A_{1, 7}[T3]  A_{1, 8}[T0]  A_{1, 9}[T0]  ... A_{1, 14}[T3]  A_{1, 15}[T3]
-    size | size 2 }   ....
-     32  |          A_{30, 0}[T60] A_{30, 1}[T60] ... A_{30, 6}[T63] A_{30, 7}[T63] A_{30, 8}[T60] A_{30, 9}[T60] ... A_{30, 14}[T63] A_{30, 15}[T63]
-         |          A_{31, 0}[T60] A_{31, 1}[T60] ... A_{31, 6}[T63] A_{31, 7}[T63] A_{31, 8}[T60] A_{31, 9}[T60] ... A_{31, 14}[T63] A_{31, 15}[T63]
-        -----------------------------/\-----------------------------------
-                               warp tile size 8
-
-    A_{32, 0}[T0]  A_{32, 1}[T0]  ... A_{32, 6}[T3]  A_{32, 7}[T3]  A_{32, 8}[T0]  A_{32, 9}[T0]  ... A_{32, 14}[T3]  A_{32, 15}[T3]
-    A_{33, 0}[T0]  A_{33, 1}[T0]  ... A_{33, 6}[T3]  A_{33, 7}[T3]  A_{33, 8}[T0]  A_{33, 9}[T0]  ... A_{33, 14}[T3]  A_{33, 15}[T3]
-    ...
-    A_{62, 0}[T60] A_{62, 1}[T60] ... A_{62, 6}[T63] A_{62, 7}[T63] A_{62, 8}[T60] A_{62, 9}[T60] ... A_{62, 14}[T63] A_{62, 15}[T63]
-    A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
-
-    And the associated TritonGPU MLIR
-    #LAYOUT = #triton_gpu.blocked_layout<{
-      threadTileSize = {2, 2}
-      blockTileSize = {32, 8}
-    }>
+    used to promote memory coalescing in LoadInst and StoreInst.
+    It is characterized by three tuples -- `sizePerThread`, `threadsPerWarp`, and `warpsPerCTA` -- which
+    specify, along each dimension, the number of elements owned by each CUDA thread, the number of
+    threads per warp, and the number of warps per CTA, respectively.
+
+    For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
+
+    [ 0  0  1  1  2  2  3  3  ; 32 32 33 33 34 34 35 35 ]
+    [ 0  0  1  1  2  2  3  3  ; 32 32 33 33 34 34 35 35 ]
+    [ 4  4  5  5  6  6  7  7  ; 36 36 37 37 38 38 39 39 ]
+    [ 4  4  5  5  6  6  7  7  ; 36 36 37 37 38 38 39 39 ]
+    ...
+    [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
+    [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
+
+    for
+
+    #triton_gpu.blocked<{
+      sizePerThread = [2, 2],
+      threadsPerWarp = [8, 4],
+      warpsPerCTA = [1, 2]
+    }>
+
+    // note to Da: in the current Triton codebase, `nanoTileSize = threadTileSize` and
+    // `macro-tile size = blockTileSize / threadTileSize`; it is probably clearer to keep
+    // the simpler semantics (i.e., the size of the tile owned by each thread, warp, or block).
   }];

  let builders = [
    // Custom builder that initializes sizePerWarp and sizePerCTA automatically
    // TODO: compiles on MacOS but not on Linux?
    // AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
    //                  "ArrayRef<unsigned>":$threadsPerWarp,
    //                  "ArrayRef<unsigned>":$warpsPerCTA,
    //                  "ArrayRef<unsigned>":$order), [{
    //   int rank = threadsPerWarp.size();
    //   SmallVector<unsigned, 4> sizePerWarp(rank);
    //   SmallVector<unsigned, 4> sizePerCTA(rank);
    //   for (unsigned i = 0; i < rank; i++) {
    //     sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
    //     sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
    //   }
    //   return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
    // }]>,
    // Default builder: takes sizePerThread, order and numWarps, and tries to
    // pack numWarps*32 threads in the provided order for use in a type
    // of the given shape.
    AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                     "ArrayRef<unsigned>":$sizePerThread,
                     "ArrayRef<unsigned>":$order,
                     "unsigned":$numWarps), [{
      int rank = sizePerThread.size();
      int remainingWarps = numWarps;
      int remainingLanes = 32;
      SmallVector<unsigned, 4> threadsPerWarp(rank);
      SmallVector<unsigned, 4> warpsPerCTA(rank);
      for (int _dim = 0; _dim < rank; ++_dim) {
        int dim = order[_dim];
        int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
        warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
        threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
        remainingWarps /= warpsPerCTA[dim];
        remainingLanes /= threadsPerWarp[dim];
      }
      return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
    }]>
  ];

   let parameters = (
     ins
-    // TODO: should we rename this as laneTileSize?
-    ArrayRefParameter<
-      "unsigned",
-      /*desc*/"size of a tile that is held by a thread"
-    >:$threadTileSize,
-    ArrayRefParameter<
-      "unsigned",
-      "size of a tile that is held by a warp"
-    >:$warpTileSize,
-    ArrayRefParameter<
-      "unsigned",
-      "size of a tile that is held by a thread block"
-    >:$blockTileSize,
-    // // TODO: it seems that we don't need this (because we can re-compute it)
-    // ArrayRefParameter<"unsigned">:$repetitions,
+    ArrayRefParameter<"unsigned">:$sizePerThread,
+    ArrayRefParameter<"unsigned">:$threadsPerWarp,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     // fastest-changing axis first
     ArrayRefParameter<
       "unsigned",
       "order of axes by rate of change"
-    >:$order,
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-    // "AffineMap":$threadOrdering,
-    // "AffineMap":$warpOrdering,
-    // "AffineMap":$blockOrdering,
+    >:$order
+    // These attributes can be inferred from the rest:
+    // ArrayRefParameter<"unsigned">:$sizePerWarp,
+    // ArrayRefParameter<"unsigned">:$sizePerCTA
   );

   // let genVerifyDecl = 1;
 }
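The default builder above greedily packs the 32 lanes of a warp, and then the warps of the CTA, along the tensor's axes following `order` (fastest-changing axis first). A Python transcription of that loop (an added sketch, not part of the commit; the function name is ours):

```python
def pack_blocked(shape, size_per_thread, order, num_warps, warp_size=32):
    """Mirror of the default AttrBuilder: returns (threadsPerWarp, warpsPerCTA)."""
    rank = len(shape)
    threads_per_warp = [1] * rank
    warps_per_cta = [1] * rank
    remaining_warps, remaining_lanes = num_warps, warp_size
    for dim in order:  # fastest-changing axis first
        max_threads = shape[dim] // size_per_thread[dim]
        # std::clamp(value, 1, maxNumThreads)
        warps_per_cta[dim] = max(1, min(remaining_warps, max_threads))
        threads_per_warp[dim] = max(1, min(remaining_lanes, max_threads))
        remaining_warps //= warps_per_cta[dim]
        remaining_lanes //= threads_per_warp[dim]
    return threads_per_warp, warps_per_cta

# A 128x128 row-major tensor with 1 element per thread and 4 warps:
# all 32 lanes and all 4 warps land on the fastest axis (axis 1).
print(pack_blocked([128, 128], [1, 1], [1, 0], 4))  # ([1, 32], [1, 4])
```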

-def TritonGPUBlockedMulticastEncodingAttr
-    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
-  let mnemonic = "blocked_multicast_layout";
-
-  let description = [{
-    to be broadcasted to blocked_layout
-  }];
-
-  // This needs to be synced with BlockedEncoding
-  let parameters = (
-    ins
-    ArrayRefParameter<"unsigned">:$threadTileSize,
-    ArrayRefParameter<"unsigned">:$warpTileSize,
-    ArrayRefParameter<"unsigned">:$blockTileSize,
-    ArrayRefParameter<"unsigned">:$order,
-    // unique to broadcasted layout
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-  );
-
-  // let genVerifyDecl = 1;
-}

 //===----------------------------------------------------------------------===//
 // MMA Layout Encoding
 //===----------------------------------------------------------------------===//
+// TODO: MMAv1 and MMAv2 should be two instances of the same class

 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
-  let mnemonic = "mma_layout";
-
-  let description = [{TODO: I think we may be able to implement it as a special case of the Distributed encoding, with maybe one more warpTileSize attribute!}];
-
-  let parameters = (
-    ins
-    // only used by Volta mma.884
-    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
-    // aka shapeOfInstr (e.g., {16,8,16})
-    ArrayRefParameter<"unsigned">:$shapePerWarp,
-    // TODO: should we rename this as warpTileSize? (consistent naming with the Distributed layout)
-    ArrayRefParameter<"unsigned">:$warpPerTile,
-    // TODO: should we rename this as blockTileSize? (consistent naming with the Distributed layout)
-    ArrayRefParameter<"unsigned">:$shapePerTile,
-    // TODO: should the Distributed layout also have this?
-    ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread,
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-    // "AffineMap":$warpOrdering,
-    // "AffineMap":$blockOrdering
-  );
-
-  // let genVerifyDecl = 1;
-}
-
-def TritonGPUMmaMulticastEncodingAttr
-    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
-  let mnemonic = "mma_multicast_layout";
-
-  let description = [{
-    To be broadcasted to mma.
-  }];
-
-  // This needs to be synced with MmaEncoding
-  let parameters = (
-    ins
-    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
-    ArrayRefParameter<"unsigned">:$shapePerWarp,
-    ArrayRefParameter<"unsigned">:$warpPerTile,
-    ArrayRefParameter<"unsigned">:$shapePerTile,
-    ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread,
-    // unique to broadcasted layout
-    ArrayRefParameter<"unsigned">:$broadcastAxis
-  );
-
-  // let genVerifyDecl = 1;
-}
+  let mnemonic = "mma";
+
+  let description = [{
+    An encoding for tensors that have been produced by tensor cores.
+    It is characterized by two parameters:
+    - A `version` which specifies the generation of the tensor cores
+      whose output is being partitioned: 1 for first-gen tensor cores (Volta),
+      and 2 for second-gen tensor cores (Turing/Ampere).
+    - A `warpsPerCTA` which indicates how the data should be
+      partitioned between warps.
+
+    // -------------------------------- version = 1 --------------------------- //
+
+    For first-gen tensor cores, the implicit warpTileSize is [16, 16].
+    Information about this layout can be found in the official PTX documentation
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    (mma.884 section, FP32 accumulator).
+
+    For example, the matrix L corresponding to a [32, 16] block tile (i.e. warpsPerCTA = [2, 1]) is:
+
+                                   warp 0
+    --------------------------------/\-------------------------------
+    [ 0  0  2  2  0  0  2  2  4  4  6  6  4  4  6  6  ]
+    [ 1  1  3  3  1  1  3  3  5  5  7  7  5  5  7  7  ]
+    [ 0  0  2  2  0  0  2  2  4  4  6  6  4  4  6  6  ]
+    [ 1  1  3  3  1  1  3  3  5  5  7  7  5  5  7  7  ]
+    [ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22 ]
+    [ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23 ]
+    [ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22 ]
+    [ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23 ]
+    [ 8  8  10 10 8  8  10 10 12 12 14 14 12 12 14 14 ]
+    [ 9  9  11 11 9  9  11 11 13 13 15 15 13 13 15 15 ]
+    [ .............................................. ]
+    [ .............................................. ]
+    [ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30 ]
+    [ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31 ]
+
+                             warp 1 = warp 0 + 32
+    --------------------------------/\-------------------------------
+    [ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38 ]
+    [ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39 ]
+    [ .............................................. ]
+    [ .............................................. ]
+    [ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62 ]
+    [ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63 ]
+
+    // -------------------------------- version = 2 --------------------------- //
+
+    For second-gen tensor cores, the implicit warpTileSize is [16, 8].
+    Information about this layout can be found in the official PTX documentation
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+    (mma.16816 section, FP32 accumulator).
+
+    For example, the matrix L corresponding to a [32, 16] block tile (i.e. warpsPerCTA = [2, 2]) is:
+
+                 warp 0                              warp 1
+    -----------------/\-------------   ----------------/\-------------
+    [ 0   0   1   1   2   2   3   3     32  32  33  33  34  34  35  35 ]
+    [ 4   4   5   5   6   6   7   7     36  36  37  37  38  38  39  39 ]
+    [ ..............................    .............................. ]
+    [ 28  28  29  29  30  30  31  31    60  60  61  61  62  62  63  63 ]
+    [ 0   0   1   1   2   2   3   3     32  32  33  33  34  34  35  35 ]
+    [ 4   4   5   5   6   6   7   7     36  36  37  37  38  38  39  39 ]
+    [ ..............................    .............................. ]
+    [ 28  28  29  29  30  30  31  31    60  60  61  61  62  62  63  63 ]
+
+                 warp 2                              warp 3
+    ----------------/\-------------    ----------------/\-------------
+    [ 64  64  65  65  66  66  67  67    96  96  97  97  98  98  99  99   ]
+    [ 68  68  69  69  70  70  71  71    100 100 101 101 102 102 103 103 ]
+    [ ..............................    ..............................   ]
+    [ 92  92  93  93  94  94  95  95    124 124 125 125 126 126 127 127 ]
+    [ 64  64  65  65  66  66  67  67    96  96  97  97  98  98  99  99   ]
+    [ 68  68  69  69  70  70  71  71    100 100 101 101 102 102 103 103 ]
+    [ ..............................    ..............................   ]
+    [ 92  92  93  93  94  94  95  95    124 124 125 125 126 126 127 127 ]
+  }];
+
+  let parameters = (
+    ins
+    "unsigned":$version,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA
+  );
+
+  // let genVerifyDecl = 1;
+}
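As a cross-check of the version = 2 figure, the lane owning accumulator element (row, col) of one 16x8 warp tile follows the mma.16816 FP32-accumulator pattern from the PTX documentation cited above. This is an added sketch deduced from the figure, not code from the commit:

```python
def mma_v2_owner(row, col, warp_id=0):
    """Lane owning accumulator element (row, col) in a 16x8 mma.16816 tile:
    each lane holds 2 adjacent columns in rows r and r + 8 (sketch)."""
    lane = (row % 8) * 4 + (col % 8) // 2
    return warp_id * 32 + lane

# Reproduces the first rows of the version = 2 matrix above, which repeat
# with period 8 along the rows:
print([mma_v2_owner(0, c) for c in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
print([mma_v2_owner(1, c) for c in range(8)])  # [4, 4, 5, 5, 6, 6, 7, 7]
print([mma_v2_owner(8, c) for c in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```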

 #endif
@@ -7,16 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
   let summary = "pipeline";

   let description = [{
-    scf.for() {
-      %a = load %a_ptr;
-      %b = load %b_ptr;
-
-      %d = dot %a, %b, %c;
-    }
-
-    =>
-
-    ...
+    TODO
   }];

   let constructor = "mlir::createTritonGPUPipelinePass()";

@@ -13,11 +13,11 @@ namespace mlir {

 class TritonGPUTypeConverter : public TypeConverter {
 public:
-  TritonGPUTypeConverter(MLIRContext *context, int numThreads);
+  TritonGPUTypeConverter(MLIRContext *context, int numWarps);

 private:
   MLIRContext *context;
-  int numThreads;
+  int numWarps;
 };

 class TritonGPUConversionTarget : public ConversionTarget {

@@ -26,9 +26,6 @@ class TritonGPUConversionTarget : public ConversionTarget {
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx,
                                      TritonGPUTypeConverter &typeConverter);
-
-  /// update layouts & insert ConvertLayoutOp if necessary
-  LogicalResult refineLayouts(ModuleOp mod, int numThreads);
 };

 } // namespace mlir

@@ -48,17 +48,17 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
       divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
     }
   }
-  ContiguityT contiguity(rank, 1);
-  DivisibilityT divisibility(rank, divHint);
-  ConstancyT constancy(rank, 1);
+  DimVectorT contiguity(rank, 1);
+  DimVectorT divisibility(rank, divHint);
+  DimVectorT constancy(rank, 1);
   return AxisInfo(contiguity, divisibility, constancy);
 }

 // The gcd of both arguments for each dimension
 AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
-  ContiguityT retContiguity;
-  DivisibilityT retDivisibility;
-  ConstancyT retConstancy;
+  DimVectorT retContiguity;
+  DimVectorT retDivisibility;
+  DimVectorT retConstancy;
   for (size_t d = 0; d < lhs.getRank(); d++) {
     retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
     retDivisibility.push_back(
@@ -78,9 +78,9 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
     const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
     const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
   int rank = lhsInfo.getRank();
-  AxisInfo::ContiguityT newContiguity;
-  AxisInfo::DivisibilityT newDivisibility;
-  AxisInfo::ConstancyT newConstancy;
+  AxisInfo::DimVectorT newContiguity;
+  AxisInfo::DimVectorT newDivisibility;
+  AxisInfo::DimVectorT newConstancy;
   for (size_t d = 0; d < rank; d++) {
     newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
     newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
@@ -101,9 +101,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
           llvm::dyn_cast<triton::MakeRangeOp>(op)) {
     int start = make_range.start();
     int end = make_range.end();
-    AxisInfo::ContiguityT contiguity = {end - start};
-    AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
-    AxisInfo::ConstancyT constancy = {1};
+    AxisInfo::DimVectorT contiguity = {end - start};
+    AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
+    AxisInfo::DimVectorT constancy = {1};
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
   // Constant
@@ -119,9 +119,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     auto value = splatAttr.getSplatValue<int>();
     TensorType ty = splatAttr.getType().cast<TensorType>();
     curr = AxisInfo(
-        AxisInfo::ContiguityT(ty.getRank(), 1),
-        AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
-        AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
+        AxisInfo::DimVectorT(ty.getRank(), 1),
+        AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
+        AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
   }
 }
 // Addition
@@ -156,9 +156,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   Type _retTy = *op->result_type_begin();
   TensorType retTy = _retTy.cast<TensorType>();
   AxisInfo opInfo = operands[0]->getValue();
-  AxisInfo::ContiguityT contiguity;
-  AxisInfo::DivisibilityT divisibility;
-  AxisInfo::ConstancyT constancy;
+  AxisInfo::DimVectorT contiguity;
+  AxisInfo::DimVectorT divisibility;
+  AxisInfo::DimVectorT constancy;
   for (size_t d = 0; d < retTy.getRank(); d++) {
     contiguity.push_back(1);
     divisibility.push_back(opInfo.getDivisibility(0));
@@ -167,8 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     curr = AxisInfo(contiguity, divisibility, constancy);
   }
-  // Reshape
   // TODO: Replace by `unsqueeze`
-  if (llvm::isa<triton::ReshapeOp>(op)) {
+  if (llvm::isa<triton::ViewOp>(op)) {
     Type _retTy = *op->result_type_begin();
     Type _opTy = *op->operand_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();

@@ -176,9 +175,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     ArrayRef<int64_t> retShape = retTy.getShape();
     ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::ContiguityT contiguity;
-    AxisInfo::DivisibilityT divisibility;
-    AxisInfo::ConstancyT constancy;
+    AxisInfo::DimVectorT contiguity;
+    AxisInfo::DimVectorT divisibility;
+    AxisInfo::DimVectorT constancy;
     bool is_skewed = false;
     size_t current = 0;
     for (size_t d = 0; d < retTy.getRank(); d++) {
@@ -209,9 +208,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
     ArrayRef<int64_t> retShape = retTy.getShape();
     ArrayRef<int64_t> opShape = opTy.getShape();
     AxisInfo opInfo = operands[0]->getValue();
-    AxisInfo::ContiguityT contiguity;
-    AxisInfo::DivisibilityT divisibility;
-    AxisInfo::ConstancyT constancy;
+    AxisInfo::DimVectorT contiguity;
+    AxisInfo::DimVectorT divisibility;
+    AxisInfo::DimVectorT constancy;
     for (size_t d = 0; d < retTy.getRank(); d++) {
       contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
       divisibility.push_back(opInfo.getDivisibility(d));

@@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
   }
 };

+struct TritonBroadcastPattern
+    : public OpConversionPattern<triton::BroadcastOp> {
+  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
+
+  // This creates a tensor with the new shape but the argument's layout
+  LogicalResult
+  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = adaptor.src().getType().cast<RankedTensorType>();
+    auto srcEncoding = srcType.getEncoding();
+    if (!srcEncoding)
+      return failure();
+    auto opType = op.getType().cast<RankedTensorType>();
+    Type retType = RankedTensorType::get(opType.getShape(),
+                                         opType.getElementType(), srcEncoding);
+    // Type retType = this->getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
+                                                     adaptor.getOperands());
+    return success();
+  }
+};
+
 struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
   using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;

@@ -234,12 +256,12 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<TritonGenericPattern<triton::ReshapeOp>,
-               TritonGenericPattern<triton::SplatOp>,
-               TritonGenericPattern<triton::BroadcastOp>,
-               TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
-               TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
-               TritonStorePattern>(typeConverter, context);
+  patterns.add< // TODO: view should have a custom pattern that views the layout
+      TritonGenericPattern<triton::ViewOp>,
+      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
+      TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
+      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+      TritonStorePattern>(typeConverter, context);
 }

 //
@@ -317,9 +339,8 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
-    int numThreads = numWarps * 32;
     // type converter
-    TritonGPUTypeConverter typeConverter(context, numThreads);
+    TritonGPUTypeConverter typeConverter(context, numWarps);
     TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);

@@ -335,8 +356,8 @@ public:

     // update layouts
     // broadcast src => multicast, dst => broadcasted
-    if (failed(target.refineLayouts(mod, numWarps)))
-      return signalPassFailure();
+    // if (failed(target.refineLayouts(mod, numWarps)))
+    //   return signalPassFailure();
   }
 };

@@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
   return success();
 };

+static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
+                               unsigned &value, StringRef desc) {
+  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+  if (!intAttr) {
+    parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
+    return failure();
+  }
+  value = intAttr.getUInt();
+  return success();
+};
+
 //===----------------------------------------------------------------------===//
 // Attribute methods
 //===----------------------------------------------------------------------===//
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"

-static Attribute parseBlocked(AsmParser &parser, Type type) {
+//===----------------------------------------------------------------------===//
+// Blocked Encoding
+//===----------------------------------------------------------------------===//
+
+Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -46,32 +61,30 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};

-  SmallVector<unsigned, 2> threadTileSize;
-  SmallVector<unsigned, 2> warpTileSize;
-  SmallVector<unsigned, 2> blockTileSize;
+  SmallVector<unsigned, 2> sizePerThread;
+  SmallVector<unsigned, 2> threadsPerWarp;
+  SmallVector<unsigned, 2> warpsPerCTA;
   SmallVector<unsigned, 2> order;
-  SmallVector<unsigned, 2> broadcastAxis;

   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "threadTileSize") {
-      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
+    if (attr.getName() == "sizePerThread") {
+      if (parseIntArrayAttr(parser, attr, sizePerThread,
+                            "number of elements per thread")
              .failed())
         return {};
-    } else if (attr.getName() == "warpTileSize") {
-      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
+    } else if (attr.getName() == "threadsPerWarp") {
+      if (parseIntArrayAttr(parser, attr, threadsPerWarp,
+                            "number of threads per warp")
              .failed())
         return {};
-    } else if (attr.getName() == "blockTileSize") {
-      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
+    } else if (attr.getName() == "warpsPerCTA") {
+      if (parseIntArrayAttr(parser, attr, warpsPerCTA,
+                            "number of warps per CTA")
              .failed())
         return {};
     } else if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
         return {};
-    } else if (attr.getName() == "broadcastAxis") {
-      if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
-              .failed())
-        return {};
     } else {
       parser.emitError(parser.getNameLoc(), "unexpected key: ")
           << attr.getName().strref();
@@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
   }

   return parser.getChecked<TritonGPUBlockedEncodingAttr>(
-      parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
-      broadcastAxis);
-}
-
-template <class T>
-static void printBlocked(AsmPrinter &printer, const T *attr) {
-  printer << "<{"
-          << "threadTileSize = [" << attr->getThreadTileSize() << "]"
-          << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
-          << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
-          << ", order = [" << attr->getOrder() << "]"
-          << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
-          << "}>";
-}
-
-Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
-  return parseBlocked(parser, type);
+      parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
 }

 void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
-  printBlocked(printer, this);
+  printer << "<{"
+          << "sizePerThread = [" << getSizePerThread() << "]"
+          << ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
+          << ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
+          << ", order = [" << getOrder() << "]"
+          << "}>";
 }

-Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
-                                                       Type type) {
-  return parseBlocked(parser, type);
-}
-
-void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
-  printBlocked(printer, this);
-}
+//===----------------------------------------------------------------------===//
+// MMA encoding
+//===----------------------------------------------------------------------===//

-static Attribute parseMma(AsmParser &parser, Type type) {
+Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   DictionaryAttr dict;
@@ -121,76 +118,34 @@ static Attribute parseMma(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};

-  SmallVector<unsigned, 2> fragmentPerWarp;
-  SmallVector<unsigned, 2> shapePerWarp;
-  SmallVector<unsigned, 2> warpPerTile;
-  SmallVector<unsigned, 2> shapePerTile;
-  SmallVector<unsigned, 2> repetitions;
-  SmallVector<unsigned, 2> contigPerThread;
-  SmallVector<unsigned, 2> broadcastAxis;
+  unsigned version = 0;
+  SmallVector<unsigned, 2> warpsPerCTA;

   for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "fragmentPerWarp") {
-      if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
-              .failed())
+    if (attr.getName() == "version") {
+      if (parseUInt(parser, attr, version, "version").failed())
         return {};
-    } else if (attr.getName() == "shapePerWarp") {
-      if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
-              .failed())
+    }
+    if (attr.getName() == "warpsPerCTA") {
+      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
         return {};
-    } else if (attr.getName() == "warpPerTile") {
-      if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
-        return {};
-    } else if (attr.getName() == "shapePerTile") {
-      if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
-              .failed())
-        return {};
-    } else if (attr.getName() == "repetitions") {
-      if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
-        return {};
-    } else if (attr.getName() == "contigPerThread") {
-      if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
-              .failed())
-        return {};
-    } else {
-      parser.emitError(parser.getNameLoc(), "unexpected key: ")
-          << attr.getName().strref();
-      return {};
     }
   }

-  return parser.getChecked<TritonGPUMmaEncodingAttr>(
-      parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
-      shapePerTile, repetitions, contigPerThread, broadcastAxis);
-}
-
-template <class T> static void printMma(AsmPrinter &printer, T *attr) {
-  printer << "<{"
-          << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
-          << ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
-          << ", warpPerTile = [" << attr->getWarpPerTile() << "]"
-          << ", shapePerTile = [" << attr->getShapePerTile() << "]"
-          << ", repetitions = [" << attr->getRepetitions() << "]"
-          << ", contigPerThread = [" << attr->getContigPerThread() << "]"
-          << "}>";
-}
-
-Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
-  return parseMma(parser, type);
+  return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
+                                                     version, warpsPerCTA);
 }

 void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
-  printMma(printer, this);
+  printer << "<{"
+          << "version = " << getVersion() << ", "
+          << "warpsPerCTA = [" << getWarpsPerCTA() << "]"
+          << "}>";
 }

-Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
-                                                   Type type) {
-  return parseMma(parser, type);
-}
-
-void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
-  printMma(printer, this);
-}
+//===----------------------------------------------------------------------===//
+// Shared encoding
+//===----------------------------------------------------------------------===//

 Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   unsigned maxPhase = 0;
   SmallVector<unsigned, 2> order;

-  auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
-                             StringRef desc) -> LogicalResult {
-    auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
-    if (!intAttr) {
-      parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
-      return failure();
-    }
-    value = intAttr.getUInt();
-    return success();
-  };
-
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "vec") {
-      if (parseUInt(attr, vec, "vec").failed())
+      if (parseUInt(parser, attr, vec, "vec").failed())
         return {};
     } else if (attr.getName() == "perPhase") {
-      if (parseUInt(attr, perPhase, "perPhase").failed())
+      if (parseUInt(parser, attr, perPhase, "perPhase").failed())
         return {};
     } else if (attr.getName() == "maxPhase") {
-      if (parseUInt(attr, maxPhase, "maxPhase").failed())
+      if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
         return {};
     } else if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
@@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
           << "}>";
 }

+//===----------------------------------------------------------------------===//
+// ASM Interface (i.e.: alias)
+//===----------------------------------------------------------------------===//
+
 class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
 public:
   using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -257,72 +205,18 @@ public:
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
     if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
       os << "mma";
-      TritonGPUOpAsmInterface::printMma(mmaAttr, os);
       return AliasResult::FinalAlias;
-    } else if (auto mmaMulticastAttr =
-                   attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
-      os << "mma_multicast";
-      TritonGPUOpAsmInterface::printMma(mmaAttr, os);
-      return AliasResult::FinalAlias;
     } else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
       os << "shared";
-      TritonGPUOpAsmInterface::printShared(sharedAttr, os);
       return AliasResult::FinalAlias;
     } else if (auto blockedAttr =
                    attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
       os << "blocked";
-      TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
       return AliasResult::FinalAlias;
-    } else if (auto blockedMulticastAttr =
-                   attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
-      os << "blocked_multicast";
-      TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
-      return AliasResult::FinalAlias;
     }
     OpAsmDialectInterface::getAlias(attr, os);
     return AliasResult::FinalAlias;
   }
-
-private:
-  static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
-    TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
-  }
-
-  static void printShared(const TritonGPUSharedEncodingAttr &attr,
-                          raw_ostream &os) {
-    os << "_" << attr.getVec();
-    os << "_" << attr.getPerPhase();
-    os << "_" << attr.getMaxPhase();
-    TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
-  }
-
-  template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
-    TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
-    TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
-  }
-
-  static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
-                         const std::string &delimiter = "x") {
-    os << "_";
-    if (array.empty()) {
-      os << "none";
-      return;
-    }
-    for (unsigned i = 0; i < array.size(); i++) {
-      os << array[i];
-      if (i != array.size() - 1) {
-        os << delimiter;
-      }
-    }
-  }
 };

 void TritonGPUDialect::initialize() {

@@ -3,6 +3,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <algorithm>
+#include <numeric>

 using namespace mlir;
 using namespace mlir::triton::gpu;
@@ -11,54 +12,26 @@ using namespace mlir::triton::gpu;
 // TypeConverter
 //
 TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-                                               int numThreads)
-    : context(context), numThreads(numThreads) {
+                                               int numWarps)
+    : context(context), numWarps(numWarps) {
   // TODO: how does MLIR pick the right conversion?
   addConversion([](Type type) { return type; });
   addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
-    MLIRContext *context = this->context;
-    int numThreads = this->numThreads;
-
-    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
-    Type elementType = tensorType.getElementType();
-    int64_t rank = tensorType.getRank();
-    int64_t numElements = tensorType.getNumElements();
-
-    // TODO: are there any better ways to raise this error?
-    if (!(numElements >= numThreads)) {
-      SmallVector<char> buffer;
-      llvm::raw_svector_ostream os(buffer);
-      os << tensorType << " has " << numElements << " numElements "
-         << " smaller than numThreads (" << numThreads << ")\n"
-         << "consider using smaller num-warps\n";
-      llvm::report_fatal_error(os.str());
-    }
-    assert(numElements % numThreads == 0);
-
-    // or assert no encoding?
-
-    // Now we assume:
-    //   contiguous = 1, order = 0, 1, 2, ...,
-    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
-    llvm::SmallVector<unsigned> warpTileSize(rank, 1);
-    llvm::SmallVector<unsigned> blockTileSize(rank);
-    llvm::SmallVector<unsigned> order(rank);
-    llvm::SmallVector<unsigned> broadcastAxis;
-    int remainingThreads = numThreads;
-    int remainingLanes = /*warp size*/ 32;
-    for (int64_t dim = 0; dim < rank; ++dim) {
-      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
-      warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
-      order[dim] = dim;
-
-      remainingThreads /= blockTileSize[dim];
-      remainingLanes /= warpTileSize[dim];
-      // TODO: will we need repetition?
-    }
-    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order,
-        broadcastAxis);
-    return RankedTensorType::get(shape, elementType, encoding);
+    // types with encoding are already in the right format
+    // TODO: check for layout encodings specifically
+    if (tensorType.getEncoding())
+      return tensorType;
+    // pessimistic values for attributes:
+    //   - 1 element per thread
+    //   - order = arange(rank)
+    ArrayRef<int64_t> shape = tensorType.getShape();
+    int rank = shape.size();
+    llvm::SmallVector<unsigned> order(rank);
+    std::iota(order.begin(), order.end(), 0);
+    llvm::SmallVector<unsigned> sizePerThread(rank, 1);
+    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
+        this->context, shape, sizePerThread, order, this->numWarps);
+    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
   });

 //
@@ -86,8 +59,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
-    llvm_unreachable("Not implemented");
-    return llvm::None;
+    auto cast =
+        builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
+    return Optional<Value>(cast.getResult());
+    // return Optional<Value>(cast.getResult(0));
+    // llvm_unreachable("Not implemented");
+    // return llvm::None;
   });
 }
@@ -122,87 +99,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
         aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
         bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
       return true;
-    // // TODO: we should delete this
-    // if (this->typeConverter.isLegal(dotOp))
-    //   return true;
     return false;
   });
 }
-
-// %dst = tt.broadcast %src
-//   =>
-// %newSrc = convert_layout %src
-// %bcst = tt.broadcast %newSrc
-// %dst = convert_layout %bcst
-LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
-                                                       int numThreads) {
-  // collect broadcasts
-  SmallVector<triton::BroadcastOp> broadcasts;
-  mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
-
-  BlockAndValueMapping mapping;
-  for (auto broadcast : broadcasts) {
-    OpBuilder builder(broadcast);
-    Value src = mapping.lookupOrDefault(broadcast.src());
-    Type originSrcType = src.getType();
-    Type originDstType = broadcast.getType();
-    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
-    unsigned dstRank = originDstTensorType.getRank();
-
-    // compute newSrcType & broadcastAxis
-    Type newSrcType;
-    SmallVector<unsigned> broadcastAxis;
-    bool isSrcScalar = false;
-    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
-      assert(tensorType.getRank() == dstRank &&
-             "src & dst should have same rank (verifier should catch this)");
-      for (unsigned ax = 0; ax < dstRank; ++ax)
-        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
-          broadcastAxis.push_back(ax);
-
-      Attribute originSrcEnc = tensorType.getEncoding();
-      if (auto blockedEnc =
-              originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
-        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
-            blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
-            blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
-            blockedEnc.getOrder(), broadcastAxis);
-        newSrcType = RankedTensorType::get(
-            tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
-      } else
-        llvm_unreachable("src of broadcast should have blocked encoding");
-    } else {
-      for (unsigned ax = 0; ax < dstRank; ++ax)
-        broadcastAxis.push_back(ax);
-      newSrcType = originSrcType;
-      isSrcScalar = true;
-    }
-
-    // create new src
-    if (!isSrcScalar) // we don't need to convert the layout for scalar values
-      src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
-                                                         newSrcType, src);
-
-    // create new broadcast
-    // compute new type (encoding)
-    auto originDstEnc = originDstTensorType.getEncoding()
-                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
-    auto newEnc = TritonGPUBlockedEncodingAttr::get(
-        originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
-        originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
-        originDstEnc.getOrder(), broadcastAxis);
-    auto newType =
-        RankedTensorType::get(originDstTensorType.getShape(),
-                              originDstTensorType.getElementType(), newEnc);
-    Value newBroadcast =
-        builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
-    // we don't want to change the encoding of the result
-    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
-        broadcast.getLoc(), originDstType, newBroadcast);
-
-    broadcast.replaceAllUsesWith(newDst);
-    mapping.map(broadcast, newDst);
-  }
-
-  return success();
-}
@@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     off_m = tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, BLOCK_N)
-    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
-    Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
+    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
     tl.store(Zs, tl.load(Xs))


-ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
+ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
 print(ret)

@@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) {
             self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
           })
       // Block instruction
-      .def("create_reshape",
+      .def("create_view",
           [](mlir::OpBuilder &self, mlir::Value &arg,
              std::vector<int64_t> &shape) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             auto argType = arg.getType()
                                .dyn_cast<mlir::RankedTensorType>()
                                .getElementType();
-            return self.create<mlir::triton::ReshapeOp>(
+            return self.create<mlir::triton::ViewOp>(
                 loc, mlir::RankedTensorType::get(shape, argType), arg);
           })
       .def("create_cat",

@@ -565,7 +565,7 @@ class tensor:
             elif sl == slice(None, None, None):
                 dst_shape.append(src_shape[curr].value)
                 curr += 1
-        ret = semantic.reshape(self, dst_shape, _builder)
+        ret = semantic.view(self, dst_shape, _builder)
         return ret

     @builtin

@@ -451,16 +451,16 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
 # ===----------------------------------------------------------------------===//


-def reshape(input: tl.tensor,
-            dst_shape: List[int],
-            builder: ir.builder) -> tl.tensor:
+def view(input: tl.tensor,
+         dst_shape: List[int],
+         builder: ir.builder) -> tl.tensor:
     numel = 1
     for s in dst_shape:
         numel *= s
     if input.type.numel != numel:
         raise ValueError("cannot reshape block of different shape")
     ret_ty = tl.block_type(input.type.scalar, dst_shape)
-    return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
+    return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)


 def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:

@@ -10,7 +10,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
   %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]

@@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
   %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]

@@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
   %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
-  %11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  %11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
   %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
   %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
   // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
-  %14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  %14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
   %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
   // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]

@@ -1,26 +0,0 @@
-// RUN: triton-opt %s -split-input-file -verify-diagnostics
-
-#reg = #triton_gpu.blocked_layout<{
-  threadTileSize = [1, 1],
-  warpTileSize = [32, 1],
-  blockTileSize = [64, 1],
-  order = [0]
-}>
-
-#reg2 = #triton_gpu.blocked_layout<{
-  threadTileSize = [2, 1],
-  warpTileSize = [64, 1],
-  blockTileSize = [128, 1],
-  order = [0]
-}>
-
-func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) {
-  %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg>
-  return
-}
-
-func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}}
-  // expected-error @+1 {{use of value '%arg0' expects different type than prior uses}}
-  %0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2>
-  return
-}

@@ -1,45 +1,12 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier
+// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s

 // 4 warps
 // matmul: 128x32 @ 32x128 -> 128x128
-#AL = #triton_gpu.blocked_layout<{
-  threadTileSize = [1, 4],
-  warpTileSize = [4, 32],
-  blockTileSize = [16, 32],
-  order = [1, 0]
-}>
-
-#BL = #triton_gpu.blocked_layout<{
-  threadTileSize = [1, 4],
-  warpTileSize = [1, 128],
-  blockTileSize = [4, 128],
-  order = [1, 0]
-}>
-
-#A = #triton_gpu.shared_layout<{
-  vec = 2,
-  perPhase = 2,
-  maxPhase = 4,
-  order = [1, 0]
-}>
-
-#B = #triton_gpu.shared_layout<{
-  vec = 2,
-  perPhase = 2,
-  maxPhase = 4,
-  order = [1, 0]
-}>
-
-// TODO: check this
-#C = #triton_gpu.mma_layout<{
-  fragmentPerWarp = [1, 1],
-  shapePerWarp = [16, 8],
-  warpPerTile = [2, 2],
-  shapePerTile = [32, 16],
-  repetitions = [4, 4],
-  contigPerThread = [1, 8]
-}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>

 // CHECK: func @matmul_loop
 // CHECK: %[[A0:.*]] = triton_gpu.copy_async