[TritonGPU] Pretty printer for layouts (#21)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||||
@@ -252,10 +253,83 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
|||||||
printer << "<{"
|
printer << "<{"
|
||||||
<< "vec = " << getVec()
|
<< "vec = " << getVec()
|
||||||
<< ", perPhase = " << getPerPhase()
|
<< ", perPhase = " << getPerPhase()
|
||||||
|
<< ", maxPhase = " << getMaxPhase()
|
||||||
<< ", order = [" << getOrder() << "]"
|
<< ", order = [" << getOrder() << "]"
|
||||||
<< "}>";
|
<< "}>";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||||
|
public:
|
||||||
|
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||||
|
|
||||||
|
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||||
|
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||||
|
os << "mma";
|
||||||
|
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||||
|
return AliasResult::FinalAlias;
|
||||||
|
} else if (auto mmaMulticastAttr =
|
||||||
|
attr.dyn_cast<TritonGPUMmaMulticastEncodingAttr>()) {
|
||||||
|
os << "mma_multicast";
|
||||||
|
TritonGPUOpAsmInterface::printMma(mmaAttr, os);
|
||||||
|
return AliasResult::FinalAlias;
|
||||||
|
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||||
|
os << "shared";
|
||||||
|
TritonGPUOpAsmInterface::printShared(sharedAttr, os);
|
||||||
|
return AliasResult::FinalAlias;
|
||||||
|
} else if (auto blockedAttr =
|
||||||
|
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||||
|
os << "blocked";
|
||||||
|
TritonGPUOpAsmInterface::printBlocked(blockedAttr, os);
|
||||||
|
return AliasResult::FinalAlias;
|
||||||
|
} else if (auto blockedMulticastAttr =
|
||||||
|
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
|
||||||
|
os << "blocked_multicast";
|
||||||
|
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
|
||||||
|
}
|
||||||
|
OpAsmDialectInterface::getAlias(attr, os);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void printMma(const auto &attr, raw_ostream &os) {
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printShared(const auto &attr, raw_ostream &os) {
|
||||||
|
os << "_" << attr.getVec();
|
||||||
|
os << "_" << attr.getPerPhase();
|
||||||
|
os << "_" << attr.getMaxPhase();
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printBlocked(const auto &attr, raw_ostream &os) {
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
|
||||||
|
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printArray(const auto &array, raw_ostream &os,
|
||||||
|
const std::string &delimiter = "x") {
|
||||||
|
os << "_";
|
||||||
|
if (array.empty()) {
|
||||||
|
os << "none";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < array.size(); i++) {
|
||||||
|
os << array[i];
|
||||||
|
if (i != array.size() - 1) {
|
||||||
|
os << delimiter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void TritonGPUDialect::initialize() {
|
void TritonGPUDialect::initialize() {
|
||||||
addAttributes<
|
addAttributes<
|
||||||
#define GET_ATTRDEF_LIST
|
#define GET_ATTRDEF_LIST
|
||||||
@@ -265,6 +339,7 @@ void TritonGPUDialect::initialize() {
|
|||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
addInterfaces<TritonGPUOpAsmInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
Reference in New Issue
Block a user