[TritonGPU] Improved documentation and semantics of layout encodings (#30)

Philippe Tillet
2022-07-31 13:59:44 -07:00
committed by GitHub
parent e02c82c765
commit d1593e6ca8
17 changed files with 399 additions and 566 deletions

View File

@@ -10,7 +10,7 @@ jobs:
Integration-Tests:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:

View File

@@ -18,16 +18,14 @@ namespace mlir {
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef std::vector<int> ContiguityT;
typedef std::vector<int> DivisibilityT;
typedef std::vector<int> ConstancyT;
typedef SmallVector<int, 4> DimVectorT;
public:
// Default constructor
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
ConstancyT knownConstancy)
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == rank);
@@ -36,13 +34,13 @@ public:
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const ContiguityT &getContiguity() const { return contiguity; }
const DimVectorT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DivisibilityT &getDivisibility() const { return divisibility; }
const DimVectorT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const ConstancyT &getConstancy() const { return constancy; }
const DimVectorT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
@@ -78,7 +76,7 @@ private:
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
ContiguityT contiguity;
DimVectorT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
@@ -93,7 +91,7 @@ private:
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
DivisibilityT divisibility;
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
@@ -104,7 +102,7 @@ private:
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
ConstancyT constancy;
DimVectorT constancy;
// number of dimensions of the lattice
int rank;
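To make the three lattices above concrete, here is a small Python sketch (illustrative only, not the C++ implementation in this diff; the helper names are invented and the divisibility rule is simplified) that reproduces the constancy example from the comments:

```python
import numpy as np

def pow2_divisor(x: int) -> int:
    """Largest power of two dividing x (x > 0): isolate the lowest set bit."""
    return x & -x

def shortest_run(slices, step: int) -> int:
    """Shortest maximal run, over all 1-D slices, of consecutive elements
    differing by `step` (step=1 -> contiguity, step=0 -> constancy)."""
    best = None
    for s in slices:
        run = 1
        for a, b in zip(s, s[1:]):
            if b - a == step:
                run += 1
            else:
                best = run if best is None else min(best, run)
                run = 1
        best = run if best is None else min(best, run)
    return best

def axis_info(t: np.ndarray):
    cols = [list(c) for c in t.T]   # 1-D slices along dimension 0
    rows = [list(r) for r in t]     # 1-D slices along dimension 1
    contiguity   = [shortest_run(cols, 1), shortest_run(rows, 1)]
    constancy    = [shortest_run(cols, 0), shortest_run(rows, 0)]
    divisibility = [min(pow2_divisor(int(v)) for v in t[0, :]),  # first elements along dim 0
                    min(pow2_divisor(int(v)) for v in t[:, 0])]  # first elements along dim 1
    return contiguity, divisibility, constancy

t = np.array([[8, 8, 8, 8, 12, 12, 12, 12],
              [16, 16, 16, 16, 20, 20, 20, 20]])
print(axis_info(t))  # constancy comes out as [1, 4], as in the comment above
```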

View File

@@ -133,8 +133,8 @@ def TT_GEPOp : TT_Op<"getelementptr",
//
// Shape Manipulation Ops
//
def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "reshape";
def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);

View File

@@ -2,43 +2,60 @@
#define TRITONGPU_ATTRDEFS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
//===----------------------------------------------------------------------===//
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
let description = [{
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.
For example, let us consider the layout function:
\mathcal{L}(0, 0) = {0, 4}
\mathcal{L}(0, 1) = {1, 5}
\mathcal{L}(1, 0) = {2, 6}
\mathcal{L}(1, 1) = {3, 7}
Then, attaching $\mathcal{L}$ to a tensor $T$ would mean that:
- T[0,0] is owned by both CUDA threads 0 and 4
- T[0,1] is owned by both CUDA threads 1 and 5
- T[1,0] is owned by both CUDA threads 2 and 6
- T[1,1] is owned by both CUDA threads 3 and 7
Right now, Triton implements two classes of layouts: shared and distributed.
}];
}
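As a minimal illustration of the layout function described above, the example \mathcal{L} can be written down directly as a mapping from tensor indices to sets of CUDA thread ids (Python, illustrative only, not an actual Triton API):

```python
# The example layout above, written as an explicit index -> thread-set map.
layout = {
    (0, 0): {0, 4},
    (0, 1): {1, 5},
    (1, 0): {2, 6},
    (1, 1): {3, 7},
}
assert layout[(0, 1)] == {1, 5}   # T[0,1] is owned by CUDA threads 1 and 5
```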
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
let mnemonic = "shared_layout";
let mnemonic = "shared";
let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
different warps in the programs, via shared memory.
different CUDA threads in the program, via shared memory. In other words,
for all indices $i \in \mathbb{Z}^d$, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
In order to avoid shared memory bank conflicts, elements may be stored in a
swizzled layout.
For example, a swizzled row-major layout stores would store data as follows:
In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per_phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
And the associated TritonGPU MLIR:
```mlir
#SMEM = #triton_gpu.shared<{
vec = 2,
perPhase = 2,
maxPhase = 4,
order = [1, 0]
}>
```
}];
let parameters = (
@@ -49,143 +66,222 @@ And the associated TritonGPU MLIR
);
}
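A minimal sketch of the swizzling example above, assuming an XOR-based swizzle over vec-sized groups of elements (this mirrors the picture with vec=2, perPhase=2, maxPhase=4; it is not necessarily the exact lowering used by the backend, and `swizzled_col` is an invented helper):

```python
def swizzled_col(row: int, col: int, vec=2, per_phase=2, max_phase=4) -> int:
    """Column slot where element (row, col) of a row-major tile is stored."""
    phase = (row // per_phase) % max_phase
    group = (col // vec) ^ phase          # swap vec-sized groups of elements
    return group * vec + col % vec

# Row 2 is phase 1, so columns 0..3 are stored in slots [2, 3, 0, 1]:
print([swizzled_col(2, c) for c in range(4)])   # -> [2, 3, 0, 1]
```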
//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//
class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
let mnemonic = "distributed";
let description = [{
Distributed encodings have a layout function that is entirely characterized
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.
The layout function \mathcal{L} of this layout is then defined, for an
index $i \in \mathbb{Z}^d$, as follows:
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such that i_d + k_d*A.shape[d] < L.shape[d]
For example, for a tensor/layout pair
A = [x x x x x x x x]
[x x x x x x x x]
L = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
Then the data of A would be distributed as follows among the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
}
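The wrap-around/replication behaviour described above can be sketched as follows (Python, illustrative only; `owners` is an invented helper that reproduces the example, not an actual Triton API):

```python
def owners(i, A_shape, L):
    """Set of CUDA threads owning element i of a tensor of shape A_shape,
    given the layout tensor L (list of lists). Reproduces the example above."""
    n_rows, n_cols = len(L), len(L[0])
    threads = set()
    # A larger than L along a dim -> indices wrap around (modulo);
    # A smaller than L along a dim -> several L entries map onto the same element.
    for r in range(i[0] % n_rows, n_rows, A_shape[0]):
        for c in range(i[1] % n_cols, n_cols, A_shape[1]):
            threads.add(L[r][c])
    return threads

L = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(owners((0, 0), (2, 8), L))   # -> {0, 8}
print(owners((1, 4), (2, 8), L))   # -> {4, 12} (column 4 wraps back onto column 0)
```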
//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
let mnemonic = "blocked_layout";
let mnemonic = "blocked";
let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
consumed (and returned) by LoadInst.
For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 warps (i.e. 64 threads) as follows:
used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the number of elements owned by each CUDA thread, warp, and CTA respectively.
thread tile size 2
- - - - - - /\ - - - - - -
warp | thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3]
size } ....
32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63]
| A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... A_{15, 14}[T63] A_{15, 15}[T63]
-----------------------------/\-----------------------------------
warp tile size 8
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
A_{32, 0}[T0] A_{32, 1}[T0] ... A_{32, 6}[T3] A_{32, 7}[T3] A_{32, 8}[T0] A_{32, 9}[T0] ... A_{32, 14}[T3] A_{32, 15}[T3]
A_{33, 0}[T0] A_{33, 1}[T0] ... A_{33, 6}[T3] A_{33, 7}[T3] A_{33, 8}[T0] A_{33, 9}[T0] ... A_{33, 14}[T3] A_{33, 15}[T3]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
A_{62, 0}[T60] A_{62, 1}[T60] ... A_{62, 6}[T63] A_{62, 7}[T63] A_{62, 8}[T60] A_{62, 9}[T60] ... A_{62, 14}[T63] A_{62, 15}[T63]
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
for
And the associated TritonGPU MLIR
#LAYOUT = #triton_gpu.blocked_layout<{
threadTileSize = {2, 2}
blockTileSize = {32, 8}
#triton_gpu.blocked<{
sizePerThread = [2, 2]
threadsPerWarp = [8, 4]
warpsPerCTA = [1, 2]
}>
// note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize`
probably clearer to have easier semantics (i.e., size of each tile owned by a thread or a block)
}];
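A minimal sketch reproducing the 16x16 / 2-warp example above (Python, illustrative only; it assumes lane and warp ids advance fastest along the last dimension, i.e. order = [1, 0], and `owning_thread` is an invented helper, not a Triton API):

```python
def owning_thread(i, j,
                  size_per_thread=(2, 2),
                  threads_per_warp=(8, 4),
                  warps_per_cta=(1, 2)):
    """Thread id owning element (i, j) under the blocked layout above."""
    tr = (i // size_per_thread[0]) % threads_per_warp[0]
    tc = (j // size_per_thread[1]) % threads_per_warp[1]
    wr = (i // (size_per_thread[0] * threads_per_warp[0])) % warps_per_cta[0]
    wc = (j // (size_per_thread[1] * threads_per_warp[1])) % warps_per_cta[1]
    lane = tr * threads_per_warp[1] + tc      # lanes advance fastest along dim 1
    warp = wr * warps_per_cta[1] + wc
    return warp * 32 + lane

print(owning_thread(0, 0), owning_thread(0, 8), owning_thread(15, 15))
# -> 0 32 63, matching the first and last entries of the matrix above
```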
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
int remainingWarps = numWarps;
int remainingLanes = 32;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank; ++_dim) {
int dim = order[_dim];
int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
remainingWarps /= warpsPerCTA[dim];
remainingLanes /= threadsPerWarp[dim];
}
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
}]>
];
let parameters = (
ins
// TODO: should we rename this as laneTileSize?
ArrayRefParameter<
"unsigned",
/*desc*/"size of a tile that is holded by a thread"
>:$threadTileSize,
ArrayRefParameter<
"unsigned",
"size of the a tile that is holded by a warp"
>:$warpTileSize,
ArrayRefParameter<
"unsigned",
"size of a tile that is holded by a thread block"
>:$blockTileSize,
// // TODO: It seems that we don't need this (because we can re-compute this)
// ArrayRefParameter<"unsigned">:$reptitions,
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$threadOrdering,
// "AffineMap":warpOrdering,
// "AffineMap":$blockOrdering,
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
);
// let genVerifyDecl = 1;
}
def TritonGPUBlockedMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
let mnemonic = "blocked_multicast_layout";
let description = [{
to be broadcasted to blocked_layout
}];
// This needs to be synced with BlockedEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$warpTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize,
ArrayRefParameter<"unsigned">:$order,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
// MMA Layout Encoding
//===----------------------------------------------------------------------===//
// TODO: MMAv1 and MMAv2 should be two instances of the same class
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma_layout";
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
let parameters = (
ins
// only used by Volta mma.884
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
// aka shapeOfInstr (e.g., {16,8,16})
ArrayRefParameter<"unsigned">:$shapePerWarp,
// TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
ArrayRefParameter<"unsigned">:$warpPerTile,
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
ArrayRefParameter<"unsigned">:$shapePerTile,
// TODO: should Distributed layout also
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$warpOrdering,
// "AffineMap":$blockOrdering
);
// let genVerifyDecl = 1;
}
def TritonGPUMmaMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
let mnemonic = "mma_multicast_layout";
let mnemonic = "mma";
let description = [{
To be broadcasted to mma.
}];
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'version' which specifies the generation of the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A `warpsPerCTA` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, the matrix L corresponding to a [32, 16] block tile (i.e. warpsPerCTA=[2,1]) is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
[ ..............................................................
[ ..............................................................
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
[ ..............................................................
[ ..............................................................
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
// -------------------------------- version = 2 --------------------------- //
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.16816 section, FP32 accumulator).
For example, the matrix L corresponding to a [32, 16] block tile (i.e. warpsPerCTA=[2,2]) is:
warp 0 warp 1
-----------------/\------------- ----------------/\-------------
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
warp 2 warp 3
----------------/\------------- ----------------/\-------------
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
}];
// This needs to be synced with MmaEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
ArrayRefParameter<"unsigned">:$shapePerWarp,
ArrayRefParameter<"unsigned">:$warpPerTile,
ArrayRefParameter<"unsigned">:$shapePerTile,
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
"unsigned":$version,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
// let genVerifyDecl = 1;
}
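As a sanity check on the version = 2 picture above, here is a small Python sketch (illustrative only; the intra-warp mapping follows the mma.16816 FP32-accumulator fragment layout from the PTX documentation, and the 2x2 warp tiling matches the example; `owning_thread_v2` is an invented helper):

```python
def owning_thread_v2(r, c):
    """Thread id owning element (r, c) of the 32x16 accumulator example."""
    warp = (r // 16) * 2 + (c // 8)            # 2x2 warps over the 32x16 tile
    lane = ((r % 16) % 8) * 4 + (c % 8) // 2   # mma.16816 FP32 fragment layout
    return warp * 32 + lane

print(owning_thread_v2(0, 0), owning_thread_v2(16, 8), owning_thread_v2(31, 15))
# -> 0 96 127, matching the corners of the version = 2 matrix above
```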
#endif

View File

@@ -7,16 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
scf.for() {
%a = load %a_ptr;
%b = load %b_ptr;
%d = dot %a, %b, %c;
}
=>
...
TODO
}];
let constructor = "mlir::createTritonGPUPipelinePass()";

View File

@@ -13,11 +13,11 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numThreads);
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
private:
MLIRContext *context;
int numThreads;
int numWarps;
};
class TritonGPUConversionTarget : public ConversionTarget {
@@ -26,9 +26,6 @@ class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx,
TritonGPUTypeConverter &typeConverter);
/// update layouts & insert ConvertLayoutOp if necessary
LogicalResult refineLayouts(ModuleOp mod, int numThreads);
};
} // namespace mlir

View File

@@ -48,17 +48,17 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
ContiguityT contiguity(rank, 1);
DivisibilityT divisibility(rank, divHint);
ConstancyT constancy(rank, 1);
DimVectorT contiguity(rank, 1);
DimVectorT divisibility(rank, divHint);
DimVectorT constancy(rank, 1);
return AxisInfo(contiguity, divisibility, constancy);
}
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
ContiguityT retContiguity;
DivisibilityT retDivisibility;
ConstancyT retConstancy;
DimVectorT retContiguity;
DimVectorT retDivisibility;
DimVectorT retConstancy;
for (size_t d = 0; d < lhs.getRank(); d++) {
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(
@@ -78,9 +78,9 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::ContiguityT newContiguity;
AxisInfo::DivisibilityT newDivisibility;
AxisInfo::ConstancyT newConstancy;
AxisInfo::DimVectorT newContiguity;
AxisInfo::DimVectorT newDivisibility;
AxisInfo::DimVectorT newConstancy;
for (size_t d = 0; d < rank; d++) {
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
@@ -101,9 +101,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
int start = make_range.start();
int end = make_range.end();
AxisInfo::ContiguityT contiguity = {end - start};
AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
AxisInfo::ConstancyT constancy = {1};
AxisInfo::DimVectorT contiguity = {end - start};
AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
AxisInfo::DimVectorT constancy = {1};
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Constant
@@ -119,9 +119,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
auto value = splatAttr.getSplatValue<int>();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(
AxisInfo::ContiguityT(ty.getRank(), 1),
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
AxisInfo::DimVectorT(ty.getRank(), 1),
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
}
}
// Addition
@@ -156,9 +156,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
@@ -167,8 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Reshape
// TODO: Replace by `unsqueeze`
if (llvm::isa<triton::ReshapeOp>(op)) {
if (llvm::isa<triton::ViewOp>(op)) {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -176,9 +175,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
bool is_skewed = false;
size_t current = 0;
for (size_t d = 0; d < retTy.getRank(); d++) {
@@ -209,9 +208,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
AxisInfo::DimVectorT contiguity;
AxisInfo::DimVectorT divisibility;
AxisInfo::DimVectorT constancy;
for (size_t d = 0; d < retTy.getRank(); d++) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));

View File

@@ -218,6 +218,28 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
}
};
struct TritonBroadcastPattern
: public OpConversionPattern<triton::BroadcastOp> {
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
// This creates a tensor with the new shape but the argument's layout
LogicalResult
matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = adaptor.src().getType().cast<RankedTensorType>();
auto srcEncoding = srcType.getEncoding();
if (!srcEncoding)
return failure();
auto opType = op.getType().cast<RankedTensorType>();
Type retType = RankedTensorType::get(opType.getShape(),
opType.getElementType(), srcEncoding);
// Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
adaptor.getOperands());
return success();
}
};
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
@@ -234,12 +256,12 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::SplatOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern>(typeConverter, context);
}
//
@@ -317,9 +339,8 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
int numThreads = numWarps * 32;
// type converter
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUTypeConverter typeConverter(context, numWarps);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
@@ -335,8 +356,8 @@ public:
// update layouts
// broadcast src => multicast, dst => broadcasted
if (failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
// if (failed(target.refineLayouts(mod, numWarps)))
// return signalPassFailure();
}
};

View File

@@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
return success();
};
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
static Attribute parseBlocked(AsmParser &parser, Type type) {
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -46,32 +61,30 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> sizePerThread;
SmallVector<unsigned, 2> threadsPerWarp;
SmallVector<unsigned, 2> warpsPerCTA;
SmallVector<unsigned, 2> order;
SmallVector<unsigned, 2> broadcastAxis;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
if (attr.getName() == "sizePerThread") {
if (parseIntArrayAttr(parser, attr, sizePerThread,
"number of elements per thread")
.failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
} else if (attr.getName() == "threadsPerWarp") {
if (parseIntArrayAttr(parser, attr, threadsPerWarp,
"number of threads per warp")
.failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
} else if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA,
"number of warps per CTA")
.failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else if (attr.getName() == "broadcastAxis") {
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
@@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
}
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
}
template <class T>
static void printBlocked(AsmPrinter &printer, const T *attr) {
printer << "<{"
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
<< ", blockTileSize = [" << attr->getBlockTileSize() << "]"
<< ", order = [" << attr->getOrder() << "]"
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
<< "}>";
}
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
return parseBlocked(parser, type);
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
printer << "<{"
<< "sizePerThread = [" << getSizePerThread() << "]"
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< ", order = [" << getOrder() << "]"
<< "}>";
}
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseBlocked(parser, type);
}
//===----------------------------------------------------------------------===//
// MMA encoding
//===----------------------------------------------------------------------===//
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
printBlocked(printer, this);
}
static Attribute parseMma(AsmParser &parser, Type type) {
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
@@ -121,76 +118,34 @@ static Attribute parseMma(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> fragmentPerWarp;
SmallVector<unsigned, 2> shapePerWarp;
SmallVector<unsigned, 2> warpPerTile;
SmallVector<unsigned, 2> shapePerTile;
SmallVector<unsigned, 2> repetitions;
SmallVector<unsigned, 2> contigPerThread;
SmallVector<unsigned, 2> broadcastAxis;
unsigned version = 0;
SmallVector<unsigned, 2> warpsPerCTA;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
.failed())
if (attr.getName() == "version") {
if (parseUInt(parser, attr, version, "version").failed())
return {};
} else if (attr.getName() == "shapePerWarp") {
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
.failed())
}
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
} else if (attr.getName() == "warpPerTile") {
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
return {};
} else if (attr.getName() == "shapePerTile") {
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
.failed())
return {};
} else if (attr.getName() == "repetitions") {
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
return {};
} else if (attr.getName() == "contigPerThread") {
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
shapePerTile, repetitions, contigPerThread, broadcastAxis);
}
template <class T> static void printMma(AsmPrinter &printer, T *attr) {
printer << "<{"
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
<< ", warpPerTile = [" << attr->getWarpPerTile() << "]"
<< ", shapePerTile = [" << attr->getShapePerTile() << "]"
<< ", repetitions = [" << attr->getRepetitions() << "]"
<< ", contigPerThread = [" << attr->getContigPerThread() << "]"
<< "}>";
}
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
version, warpsPerCTA);
}
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
printer << "<{"
<< "version = " << getVersion() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseMma(parser, type);
}
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
//===----------------------------------------------------------------------===//
// Shared encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "vec") {
if (parseUInt(attr, vec, "vec").failed())
if (parseUInt(parser, attr, vec, "vec").failed())
return {};
} else if (attr.getName() == "perPhase") {
if (parseUInt(attr, perPhase, "perPhase").failed())
if (parseUInt(parser, attr, perPhase, "perPhase").failed())
return {};
} else if (attr.getName() == "maxPhase") {
if (parseUInt(attr, maxPhase, "maxPhase").failed())
if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
@@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -257,72 +205,18 @@ public:
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
os << "mma";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto mmaMulticastAttr =
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
os << "mma_multicast";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
os << "shared";
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedAttr =
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
os << "blocked";
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedMulticastAttr =
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
os << "blocked_multicast";
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
return AliasResult::FinalAlias;
}
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
}
private:
static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
}
static void printShared(const TritonGPUSharedEncodingAttr &attr,
raw_ostream &os) {
os << "_" << attr.getVec();
os << "_" << attr.getPerPhase();
os << "_" << attr.getMaxPhase();
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
}
template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
}
static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
const std::string &delimiter = "x") {
os << "_";
if (array.empty()) {
os << "none";
return;
}
for (unsigned i = 0; i < array.size(); i++) {
os << array[i];
if (i != array.size() - 1) {
os << delimiter;
}
}
}
};
void TritonGPUDialect::initialize() {

View File

@@ -3,6 +3,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
using namespace mlir;
using namespace mlir::triton::gpu;
@@ -11,54 +12,26 @@ using namespace mlir::triton::gpu;
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numThreads)
: context(context), numThreads(numThreads) {
int numWarps)
: context(context), numWarps(numWarps) {
// TODO: how does MLIR pick the right conversion?
addConversion([](Type type) { return type; });
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
MLIRContext *context = this->context;
int numThreads = this->numThreads;
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
Type elementType = tensorType.getElementType();
int64_t rank = tensorType.getRank();
int64_t numElements = tensorType.getNumElements();
// TODO: are there any better ways to raise this error?
if (!(numElements >= numThreads)) {
SmallVector<char> buffer;
llvm::raw_svector_ostream os(buffer);
os << tensorType << " has " << numElements << " numElements "
<< " smaller than numThreads (" << numThreads << ")\n"
<< "consider using smaller num-warps\n";
llvm::report_fatal_error(os.str());
}
assert(numElements % numThreads == 0);
// or assert no encoding?
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
// types with encoding are already in the right format
// TODO: check for layout encodings specifically
if (tensorType.getEncoding())
return tensorType;
// pessimistic values for attributes:
// - 1 element per thread
// - order = arange(rank)
ArrayRef<int64_t> shape = tensorType.getShape();
int rank = shape.size();
llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/ 32;
for (int64_t dim = 0; dim < rank; ++dim) {
blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
order[dim] = dim;
remainingThreads /= blockTileSize[dim];
remainingLanes /= warpTileSize[dim];
// TODO: will we need repetition?
}
std::iota(order.begin(), order.end(), 0);
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding);
this->context, shape, sizePerThread, order, this->numWarps);
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
});
//
@@ -86,8 +59,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
auto cast =
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
return Optional<Value>(cast.getResult());
// return Optional<Value>(cast.getResult(0));
// llvm_unreachable("Not implemented");
// return llvm::None;
});
}
@@ -122,87 +99,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
return false;
});
}
// %dst = tt.broadcast %src
// =>
// %newSrc = convert_layout %src
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
OpBuilder builder(broadcast);
Value src = mapping.lookupOrDefault(broadcast.src());
Type originSrcType = src.getType();
Type originDstType = broadcast.getType();
auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
unsigned dstRank = originDstTensorType.getRank();
// compute newSrcType & broadcastAxis
Type newSrcType;
SmallVector<unsigned> broadcastAxis;
bool isSrcScalar = false;
if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
assert(tensorType.getRank() == dstRank &&
"src & dst should have same rank (verifier should catch this)");
for (unsigned ax = 0; ax < dstRank; ++ax)
if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc =
originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(), broadcastAxis);
newSrcType = RankedTensorType::get(
tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
for (unsigned ax = 0; ax < dstRank; ++ax)
broadcastAxis.push_back(ax);
newSrcType = originSrcType;
isSrcScalar = true;
}
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
newSrcType, src);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(), broadcastAxis);
auto newType =
RankedTensorType::get(originDstTensorType.getShape(),
originDstTensorType.getElementType(), newEnc);
Value newBroadcast =
builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);
}
return success();
}
}

View File

@@ -10,10 +10,10 @@ def kernel(X, stride_xm, stride_xn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -1471,14 +1471,14 @@ void init_triton_ir(py::module &&m) {
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
})
// Block instruction
.def("create_reshape",
.def("create_view",
[](mlir::OpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType()
.dyn_cast<mlir::RankedTensorType>()
.getElementType();
return self.create<mlir::triton::ReshapeOp>(
return self.create<mlir::triton::ViewOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg);
})
.def("create_cat",

View File

@@ -565,7 +565,7 @@ class tensor:
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr].value)
curr += 1
ret = semantic.reshape(self, dst_shape, _builder)
ret = semantic.view(self, dst_shape, _builder)
return ret
@builtin

View File

@@ -451,16 +451,16 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
# ===----------------------------------------------------------------------===//
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
numel = 1
for s in dst_shape:
numel *= s
if input.type.numel != numel:
raise ValueError("cannot reshape block of different shape")
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:

View File

@@ -10,7 +10,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
%2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
%2 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
@@ -20,7 +20,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
%7 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
@@ -28,13 +28,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
%11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
%11 = tt.view %0 : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
%14 = tt.view %1 : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]

View File

@@ -1,26 +0,0 @@
// RUN: triton-opt %s -split-input-file -verify-diagnostics
#reg = #triton_gpu.blocked_layout<{
threadTileSize = [1, 1],
warpTileSize = [32, 1],
blockTileSize = [64, 1],
order = [0]
}>
#reg2 = #triton_gpu.blocked_layout<{
threadTileSize = [2, 1],
warpTileSize = [64, 1],
blockTileSize = [128, 1],
order = [0]
}>
func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) {
%0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg>
return
}
func @add(%arg0: tensor<256xi32, #reg>, %arg1: tensor<256xi32, #reg>) { // expected-note {{prior use here}}
// expected-error @+1 {{use of value '%arg0' expects different type than prior uses}}
%0 = arith.addi %arg0, %arg1 : tensor<256xi32, #reg2>
return
}

View File

@@ -1,45 +1,12 @@
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize -tritongpu-verifier | FileCheck %s
// 4 warps
// matmul: 128x32 @ 32x128 -> 128x128
#AL = #triton_gpu.blocked_layout<{
threadTileSize = [1, 4],
warpTileSize = [4, 32],
blockTileSize = [16, 32],
order = [1, 0]
}>
#BL = #triton_gpu.blocked_layout<{
threadTileSize = [1, 4],
warpTileSize = [1, 128],
blockTileSize = [4, 128],
order = [1, 0]
}>
#A = #triton_gpu.shared_layout<{
vec = 2,
perPhase = 2,
maxPhase = 4,
order = [1, 0]
}>
#B = #triton_gpu.shared_layout<{
vec = 2,
perPhase = 2,
maxPhase = 4,
order = [1, 0]
}>
// TODO: check this
#C = #triton_gpu.mma_layout<{
fragmentPerWarp = [1, 1],
shapePerWarp = [16, 8],
warpPerTile = [2, 2],
shapePerTile = [32, 16],
repetitions = [4, 4],
contigPerThread = [1, 8]
}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
// CHECK: func @matmul_loop
// CHECK: %[[A0:.*]] = triton_gpu.copy_async