special encoding for broadcast

This commit is contained in:
Yan Da
2022-06-18 21:16:45 +08:00
parent 53cf93ce6a
commit 9d1b5e3f79
6 changed files with 248 additions and 72 deletions

View File

@@ -4,8 +4,9 @@
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits>;
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
>:$order,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$threadOrdering,
// "AffineMap":warpOrdering,
// "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
// let genVerifyDecl = 1;
}
def TritonGPUBlockedMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
let mnemonic = "blocked_multicast_layout";
let description = [{
to be broadcasted to blocked_layout
}];
// This needs to be synced with BlockedEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$warpTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize,
ArrayRefParameter<"unsigned">:$order,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma_layout";
@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
ArrayRefParameter<"unsigned">:$shapePerTile,
// TODO: should Distributed layout also
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread
ArrayRefParameter<"unsigned">:$contigPerThread,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$warpOrdering,
// "AffineMap":$blockOrdering
);
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
// let genVerifyDecl = 1;
}
def TritonGPUMmaMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
let mnemonic = "mma_multicast_layout";
let description = [{
To be broadcasted to mma.
}];
// This needs to be synced with MmaEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
ArrayRefParameter<"unsigned">:$shapePerWarp,
ArrayRefParameter<"unsigned">:$warpPerTile,
ArrayRefParameter<"unsigned">:$shapePerTile,
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
#endif

View File

@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter;
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
/// update layouts & insert ConvertLayoutOp if necessary
LogicalResult refineLayouts(ModuleOp mod, int numThreads);
};
} // namespace mlir