[TritonGPU] Pretty printer for layouts (#21)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
@@ -252,10 +253,83 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec()
|
||||
<< ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase()
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
os << "mma";
|
||||
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto mmaMulticastAttr =
|
||||
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
|
||||
os << "mma_multicast";
|
||||
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
os << "shared";
|
||||
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
os << "blocked";
|
||||
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedMulticastAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
|
||||
os << "blocked_multicast";
|
||||
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
|
||||
}
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
|
||||
private:
|
||||
static void printMma(const auto &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
|
||||
}
|
||||
|
||||
static void printShared(const auto &attr, raw_ostream &os) {
|
||||
os << "_" << attr.getVec();
|
||||
os << "_" << attr.getPerPhase();
|
||||
os << "_" << attr.getMaxPhase();
|
||||
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||
}
|
||||
|
||||
static void printBlocked(const auto &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
|
||||
}
|
||||
|
||||
static void printArray(const auto &array, raw_ostream &os,
|
||||
const std::string &delimiter = "x") {
|
||||
os << "_";
|
||||
if (array.empty()) {
|
||||
os << "none";
|
||||
return;
|
||||
}
|
||||
for (unsigned i = 0; i < array.size(); i++) {
|
||||
os << array[i];
|
||||
if (i != array.size() - 1) {
|
||||
os << delimiter;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
@@ -265,6 +339,7 @@ void TritonGPUDialect::initialize() {
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
Reference in New Issue
Block a user