more progress on the definition of layouts
This commit is contained in:
@@ -8,12 +8,14 @@ class TritonGPU_Attr<string name, list<Trait> traits = []>
|
||||
: AttrDef<TritonGPU_Dialect, name, traits>;
|
||||
|
||||
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
|
||||
let mnemonic = "shared (memory) encoding";
|
||||
let mnemonic = "shared_layout";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors whose elements may be simultaneously accessed by different warps in the programs, via shared memory.
|
||||
An encoding for tensors whose elements may be simultaneously accessed by
|
||||
different warps in the program, via shared memory.
|
||||
|
||||
In order to avoid shared memory bank conflicts, elements may be stored in a swizzled layout.
|
||||
In order to avoid shared memory bank conflicts, elements may be stored in a
|
||||
swizzled layout.
|
||||
For example, a swizzled row-major layout would store data as follows:
|
||||
|
||||
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
||||
@@ -29,10 +31,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
And the associated TritonGPU MLIR
|
||||
|
||||
```mlir
|
||||
#SMEM = #triton_gpu.encoding<{
|
||||
#SMEM = #triton_gpu.shared_layout<{
|
||||
vec = 2,
|
||||
perPhase = 2,
|
||||
maxPhase = 4
|
||||
maxPhase = 4,
|
||||
order = [1, 0]
|
||||
}>
|
||||
```
|
||||
}];
|
||||
@@ -40,12 +43,13 @@ And the associated TritonGPU MLIR
|
||||
let parameters = (
|
||||
ins
|
||||
// swizzle info
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||
);
|
||||
}
|
||||
|
||||
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
|
||||
let mnemonic = "coalesced encoding";
|
||||
def TritonGPUShardedEncodingAttr : TritonGPU_Attr<"TritonGPUShardedEncoding"> {
|
||||
let mnemonic = "sharded_layout";
|
||||
|
||||
let description = [{
|
||||
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
||||
@@ -70,7 +74,7 @@ size } ....
|
||||
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
|
||||
|
||||
And the associated TritonGPU MLIR
|
||||
#SMEM = #triton_gpu.encoding<{
|
||||
#LAYOUT = #triton_gpu.sharded_layout<{
|
||||
threadTileSize = {2, 2}
|
||||
blockTileSize = {32, 8}
|
||||
}>
|
||||
@@ -81,28 +85,55 @@ And the associated TritonGPU MLIR
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$threadTileSize,
|
||||
ArrayRefParameter<"unsigned">:$blockTileSize,
|
||||
// TODO: should we rename this as laneTileSize?
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
/*desc*/"size of a tile that is holded by a thread"
|
||||
>:$threadTileSize,
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"size of the a tile that is holded by a warp"
|
||||
>:$warpTileSize,
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"size of a tile that is holded by a thread block"
|
||||
>:$blockTileSize,
|
||||
// // TODO: It seems that we don't need this (because we can re-compute this)
|
||||
// ArrayRefParameter<"unsigned">:$reptitions,
|
||||
// fastest-changing axis first
|
||||
ArrayRefParameter<"unsigned">:$order
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"order of axes by the rate of changing"
|
||||
>:$order
|
||||
// "AffineMap":$threadOrdering,
|
||||
// "AffineMap":warpOrdering,
|
||||
// "AffineMap":$blockOrdering,
|
||||
|
||||
);
|
||||
|
||||
// let genVerifyDecl = 1;
|
||||
}
|
||||
|
||||
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
||||
let mnemonic = "mma encoding";
|
||||
let mnemonic = "mma_layout";
|
||||
|
||||
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
// only used by Volta mma.884
|
||||
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
|
||||
// aka shapeOfInstr (e.g., {16,8,16})
|
||||
ArrayRefParameter<"unsigned">:$shapePerWarp,
|
||||
// TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
|
||||
ArrayRefParameter<"unsigned">:$warpPerTile,
|
||||
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
|
||||
ArrayRefParameter<"unsigned">:$shapePerTile,
|
||||
// TODO: should Distributed layout also have this parameter? (original note was left unfinished — confirm intent)
|
||||
ArrayRefParameter<"unsigned">:$reptitions,
|
||||
ArrayRefParameter<"unsigned">:$contigPerThread
|
||||
// "AffineMap":$warpOrdering,
|
||||
// "AffineMap":$blockOrdering
|
||||
);
|
||||
|
||||
// let genVerifyDecl = 1;
|
||||
|
Reference in New Issue
Block a user