special encoding for broadcast
This commit is contained in:
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
SmallVector<unsigned, 2> &res,
|
||||
/*SmallVector<unsigned, 2>*/auto &res,
|
||||
StringRef desc) {
|
||||
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr) {
|
||||
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
SmallVector<unsigned, 2> warpTileSize;
|
||||
SmallVector<unsigned, 2> blockTileSize;
|
||||
SmallVector<unsigned, 2> order;
|
||||
|
||||
// parse an array of integers
|
||||
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
|
||||
// SmallVector<unsigned, 2> &res,
|
||||
// StringRef desc) -> LogicalResult {
|
||||
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||
// if (!arrayAttr) {
|
||||
// parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||
// << desc;
|
||||
// return failure();
|
||||
// }
|
||||
// for (Attribute i : arrayAttr) {
|
||||
// auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||
// if (!intAttr) {
|
||||
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||
// << desc;
|
||||
// return failure();
|
||||
// }
|
||||
// res.push_back(intAttr.getUInt());
|
||||
// }
|
||||
// return success();
|
||||
// };
|
||||
SmallVector<unsigned, 2> broadcastAxis;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "threadTileSize") {
|
||||
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
threadTileSize,
|
||||
warpTileSize,
|
||||
blockTileSize,
|
||||
order);
|
||||
order,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
printer << "<{"
|
||||
<< "threadTileSize = [" << getThreadTileSize() << "]"
|
||||
<< ", warpTileSize = [" << getWarpTileSize() << "]"
|
||||
<< ", blockTileSize = [" << getBlockTileSize() << "]"
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
|
||||
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
|
||||
<< ", blockTileSize = [" << attr->getBlockTileSize() << "]"
|
||||
<< ", order = [" << attr->getOrder() << "]"
|
||||
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
}
|
||||
|
||||
static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
DictionaryAttr dict;
|
||||
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
SmallVector<unsigned, 2> shapePerTile;
|
||||
SmallVector<unsigned, 2> repetitions;
|
||||
SmallVector<unsigned, 2> contigPerThread;
|
||||
SmallVector<unsigned, 2> broadcastAxis;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
warpPerTile,
|
||||
shapePerTile,
|
||||
repetitions,
|
||||
contigPerThread);
|
||||
contigPerThread,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
printer << "<{"
|
||||
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
|
||||
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
|
||||
<< ", warpPerTile = [" << attr->getWarpPerTile() << "]"
|
||||
<< ", shapePerTile = [" << attr->getShapePerTile() << "]"
|
||||
<< ", repetitions = [" << attr->getRepetitions() << "]"
|
||||
<< ", contigPerThread = [" << attr->getContigPerThread() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
|
||||
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
|
||||
<< ", warpPerTile = [" << getWarpPerTile() << "]"
|
||||
<< ", shapePerTile = [" << getShapePerTile() << "]"
|
||||
<< ", repetitions = [" << getRepetitions() << "]"
|
||||
<< ", contigPerThread = [" << getContigPerThread() << "]"
|
||||
<< "}>";
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
|
Reference in New Issue
Block a user