[TritonGPU] Pretty printer for layouts (#21)

This commit is contained in:
Keren Zhou
2022-07-26 10:50:11 -07:00
committed by GitHub
parent 27c9f3d8cb
commit 96cc6fb563

View File

@@ -1,5 +1,6 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
@@ -252,10 +253,83 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "vec = " << getVec()
<< ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase()
<< ", order = [" << getOrder() << "]"
<< "}>";
}
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
os << "mma";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto mmaMulticastAttr =
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
os << "mma_multicast";
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
os << "shared";
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedAttr =
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
os << "blocked";
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
return AliasResult::FinalAlias;
} else if (auto blockedMulticastAttr =
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
os << "blocked_multicast";
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
}
OpAsmDialectInterface::getAlias(attr, os);
}
private:
static void printMma(const auto &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
}
static void printShared(const auto &attr, raw_ostream &os) {
os << "_" << attr.getVec();
os << "_" << attr.getPerPhase();
os << "_" << attr.getMaxPhase();
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
}
static void printBlocked(const auto &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
}
static void printArray(const auto &array, raw_ostream &os,
const std::string &delimiter = "x") {
os << "_";
if (array.empty()) {
os << "none";
return;
}
for (unsigned i = 0; i < array.size(); i++) {
os << array[i];
if (i != array.size() - 1) {
os << delimiter;
}
}
}
};
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
@@ -265,6 +339,7 @@ void TritonGPUDialect::initialize() {
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
addInterfaces<TritonGPUOpAsmInterface>();
}
namespace mlir {