[TritonGPU] Improved documentation and semantics of layout encodings (#30)

This commit is contained in:
Philippe Tillet
2022-07-31 13:59:44 -07:00
committed by GitHub
parent e02c82c765
commit d1593e6ca8
17 changed files with 399 additions and 566 deletions

View File

@@ -30,13 +30,28 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
return success();
};
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
static Attribute parseBlocked(AsmParser &parser, Type type) {
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -46,32 +61,30 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> sizePerThread;
SmallVector<unsigned, 2> threadsPerWarp;
SmallVector<unsigned, 2> warpsPerCTA;
SmallVector<unsigned, 2> order;
SmallVector<unsigned, 2> broadcastAxis;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
if (attr.getName() == "sizePerThread") {
if (parseIntArrayAttr(parser, attr, sizePerThread,
"number of elements per thread")
.failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
} else if (attr.getName() == "threadsPerWarp") {
if (parseIntArrayAttr(parser, attr, threadsPerWarp,
"number of threads per warp")
.failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
} else if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA,
"number of warps per CTA")
.failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else if (attr.getName() == "broadcastAxis") {
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
@@ -80,39 +93,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
}
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
broadcastAxis);
}
template <class T>
static void printBlocked(AsmPrinter &printer, const T *attr) {
printer << "<{"
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
<< ", blockTileSize = [" << attr->getBlockTileSize() << "]"
<< ", order = [" << attr->getOrder() << "]"
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
<< "}>";
}
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
return parseBlocked(parser, type);
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
printer << "<{"
<< "sizePerThread = [" << getSizePerThread() << "]"
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< ", order = [" << getOrder() << "]"
<< "}>";
}
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseBlocked(parser, type);
}
//===----------------------------------------------------------------------===//
// MMA encoding
//===----------------------------------------------------------------------===//
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
printBlocked(printer, this);
}
static Attribute parseMma(AsmParser &parser, Type type) {
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
@@ -121,76 +118,34 @@ static Attribute parseMma(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> fragmentPerWarp;
SmallVector<unsigned, 2> shapePerWarp;
SmallVector<unsigned, 2> warpPerTile;
SmallVector<unsigned, 2> shapePerTile;
SmallVector<unsigned, 2> repetitions;
SmallVector<unsigned, 2> contigPerThread;
SmallVector<unsigned, 2> broadcastAxis;
unsigned version = 0;
SmallVector<unsigned, 2> warpsPerCTA;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
.failed())
if (attr.getName() == "version") {
if (parseUInt(parser, attr, version, "version").failed())
return {};
} else if (attr.getName() == "shapePerWarp") {
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
.failed())
}
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
} else if (attr.getName() == "warpPerTile") {
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
return {};
} else if (attr.getName() == "shapePerTile") {
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
.failed())
return {};
} else if (attr.getName() == "repetitions") {
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
return {};
} else if (attr.getName() == "contigPerThread") {
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
.failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
shapePerTile, repetitions, contigPerThread, broadcastAxis);
}
template <class T> static void printMma(AsmPrinter &printer, T *attr) {
printer << "<{"
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
<< ", warpPerTile = [" << attr->getWarpPerTile() << "]"
<< ", shapePerTile = [" << attr->getShapePerTile() << "]"
<< ", repetitions = [" << attr->getRepetitions() << "]"
<< ", contigPerThread = [" << attr->getContigPerThread() << "]"
<< "}>";
}
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
version, warpsPerCTA);
}
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
printer << "<{"
<< "version = " << getVersion() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
return parseMma(parser, type);
}
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
//===----------------------------------------------------------------------===//
// Shared encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
@@ -207,26 +162,15 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "vec") {
if (parseUInt(attr, vec, "vec").failed())
if (parseUInt(parser, attr, vec, "vec").failed())
return {};
} else if (attr.getName() == "perPhase") {
if (parseUInt(attr, perPhase, "perPhase").failed())
if (parseUInt(parser, attr, perPhase, "perPhase").failed())
return {};
} else if (attr.getName() == "maxPhase") {
if (parseUInt(attr, maxPhase, "maxPhase").failed())
if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
@@ -250,6 +194,10 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -257,72 +205,18 @@ public:
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
os << "mma";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto mmaMulticastAttr =
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
os << "mma_multicast";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
os << "shared";
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedAttr =
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
os << "blocked";
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedMulticastAttr =
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
os << "blocked_multicast";
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
return AliasResult::FinalAlias;
}
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
}
private:
static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
}
static void printShared(const TritonGPUSharedEncodingAttr &attr,
raw_ostream &os) {
os << "_" << attr.getVec();
os << "_" << attr.getPerPhase();
os << "_" << attr.getMaxPhase();
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
}
template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
}
static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
const std::string &delimiter = "x") {
os << "_";
if (array.empty()) {
os << "none";
return;
}
for (unsigned i = 0; i < array.size(); i++) {
os << array[i];
if (i != array.size() - 1) {
os << delimiter;
}
}
}
};
void TritonGPUDialect::initialize() {