diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 95c8c7b30..d66a08892 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,5 +1,6 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/TypeSwitch.h" #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" @@ -252,10 +253,83 @@ void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "vec = " << getVec() << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder() << "]" << "}>"; } +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { + public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto mmaAttr = attr.dyn_cast()) { + os << "mma"; + TritonGPUOpAsmInterface::printMma(mmaAttr, os); + return AliasResult::FinalAlias; + } else if (auto mmaMulticastAttr = + attr.dyn_cast()) { + os << "mma_multicast"; + TritonGPUOpAsmInterface::printMma(mmaAttr, os); + return AliasResult::FinalAlias; + } else if (auto sharedAttr = attr.dyn_cast()) { + os << "shared"; + TritonGPUOpAsmInterface::printShared(sharedAttr, os); + return AliasResult::FinalAlias; + } else if (auto blockedAttr = + attr.dyn_cast()) { + os << "blocked"; + TritonGPUOpAsmInterface::printBlocked(blockedAttr, os); + return AliasResult::FinalAlias; + } else if (auto blockedMulticastAttr = + attr.dyn_cast()) { + os << "blocked_multicast"; + TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os); + } + OpAsmDialectInterface::getAlias(attr, os); + } + + private: + static void printMma(const auto &attr, raw_ostream &os) { + TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os); + TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os); + TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os); + TritonGPUOpAsmInterface::printArray(attr.getShapePerTile(), os); + TritonGPUOpAsmInterface::printArray(attr.getRepetitions(), os); + TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os); + } + + static void printShared(const auto &attr, raw_ostream &os) { + os << "_" << attr.getVec(); + os << "_" << attr.getPerPhase(); + os << "_" << attr.getMaxPhase(); + TritonGPUOpAsmInterface::printArray(attr.getOrder(), os); + } + + static void printBlocked(const auto &attr, raw_ostream &os) { + TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os); + TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os); + TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os); + TritonGPUOpAsmInterface::printArray(attr.getOrder(), os); + TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os); + } + + static void printArray(const auto &array, raw_ostream &os, + const std::string &delimiter = "x") { + os << "_"; + if (array.empty()) { + os << "none"; + return; + } + for (unsigned i = 0; i < array.size(); i++) { + os << array[i]; + if (i != array.size() - 1) { + os << delimiter; + } + } + } +}; + void TritonGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST @@ -265,6 +339,7 @@ void TritonGPUDialect::initialize() { #define GET_OP_LIST #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" >(); + addInterfaces(); } namespace mlir {