special encoding for broadcast
@@ -4,8 +4,9 @@
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 // include "mlir/IR/TensorEncoding.td"
 
-class TritonGPU_Attr<string name, list<Trait> traits = []>
-    : AttrDef<TritonGPU_Dialect, name, traits>;
+class TritonGPU_Attr<string name, list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Attribute">
+    : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
 
 def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
   let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
     ArrayRefParameter<
       "unsigned",
      "order of axes by the rate of changing"
-    >:$order
+    >:$order,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$threadOrdering,
     // "AffineMap":warpOrdering,
     // "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
   // let genVerifyDecl = 1;
 }
 
+def TritonGPUBlockedMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
+  let mnemonic = "blocked_multicast_layout";
+
+  let description = [{
+    to be broadcasted to blocked_layout
+  }];
+
+  // This needs to be synced with BlockedEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$threadTileSize,
+    ArrayRefParameter<"unsigned">:$warpTileSize,
+    ArrayRefParameter<"unsigned">:$blockTileSize,
+    ArrayRefParameter<"unsigned">:$order,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   let mnemonic = "mma_layout";
 
@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
     ArrayRefParameter<"unsigned">:$shapePerTile,
     // TODO: should Distributed layout also
     ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$warpOrdering,
     // "AffineMap":$blockOrdering
   );
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   // let genVerifyDecl = 1;
 }
 
+def TritonGPUMmaMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
+  let mnemonic = "mma_multicast_layout";
+
+  let description = [{
+    To be broadcasted to mma.
+  }];
+
+  // This needs to be synced with MmaEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
+    ArrayRefParameter<"unsigned">:$shapePerWarp,
+    ArrayRefParameter<"unsigned">:$warpPerTile,
+    ArrayRefParameter<"unsigned">:$shapePerTile,
+    ArrayRefParameter<"unsigned">:$repetitions,
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 #endif
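
The new multicast encodings (and the broadcastAxis parameter added to the existing blocked and mma encodings) record which axes of a tensor are broadcast. As a minimal C++ sketch of how a lowering could consult that parameter, assuming the usual ODS-generated names (def name + "Attr" for the class, get<Param>() for accessors), the mlir::triton::gpu namespace, and an illustrative header path, none of which are shown in this commit:

// Hypothetical helper (not from this commit): returns true when `dim` is one
// of the broadcast axes recorded in a TritonGPUBlockedMulticastEncodingAttr.
// The class name, accessor name, and header path below are assumptions based
// on standard ODS AttrDef code generation for the parameters declared above.
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed location of the generated attrs
#include "llvm/ADT/STLExtras.h"

static bool isBroadcastDim(mlir::Attribute encoding, unsigned dim) {
  if (auto multicast = encoding.dyn_cast<
          mlir::triton::gpu::TritonGPUBlockedMulticastEncodingAttr>()) {
    // broadcastAxis is declared as ArrayRefParameter<"unsigned"> above.
    llvm::ArrayRef<unsigned> axes = multicast.getBroadcastAxis();
    return llvm::is_contained(axes, dim);
  }
  return false;
}
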
@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
   TritonGPUTypeConverter &typeConverter;
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
+
+  /// update layouts & insert ConvertLayoutOp if necessary
+  LogicalResult refineLayouts(ModuleOp mod, int numThreads);
 };
 
 } // namespace mlir
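
For reference, a minimal sketch of how the new refineLayouts entry point could be driven. Only the constructor and refineLayouts signatures come from the header above; the wrapper function and the way numThreads is chosen are assumptions for illustration.

// Hypothetical driver (not from this commit): builds the conversion target
// declared above and asks it to refine layouts across a module.
#include "mlir/IR/BuiltinOps.h"
// plus the header declaring TritonGPUConversionTarget / TritonGPUTypeConverter

static mlir::LogicalResult
refineModuleLayouts(mlir::ModuleOp mod, mlir::MLIRContext &ctx,
                    mlir::TritonGPUTypeConverter &typeConverter,
                    int numThreads) {
  mlir::TritonGPUConversionTarget target(ctx, typeConverter);
  // Per the doc comment in the header, this updates layouts and inserts
  // ConvertLayoutOp where necessary.
  if (mlir::failed(target.refineLayouts(mod, numThreads)))
    return mod.emitError("layout refinement failed");
  return mlir::success();
}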