#include "mlir/IR/AsmState.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Alias.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include <algorithm>
#include <string>

using namespace mlir;

namespace {

/// Test pass that runs SharedMemoryAliasAnalysis over the operation it is
/// scheduled on and prints, for every value whose lattice element records
/// allocations, a line of the form `<value> -> <alloc>,<alloc>,...` to
/// llvm::errs(). Presumably consumed by lit/FileCheck tests (the pass is
/// registered as `test-print-alias`).
struct TestAliasPass
    : public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {

  // LLVM15+
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);

  /// Prints `name -> v0,v1,...` followed by a newline to `os`.
  /// Prints nothing at all when `vals` is empty.
  ///
  /// \param name  printed name of the value being reported.
  /// \param vals  names of the allocations it aliases (printed in the given
  ///              order, comma-separated without spaces).
  /// \param os    destination stream.
  static void print(StringRef name, ArrayRef<std::string> vals,
                    raw_ostream &os) {
    if (vals.empty())
      return;
    os << name << " -> ";
    bool first = true;
    // Iterate by const reference: the strings are only read, never copied.
    for (const std::string &val : vals) {
      if (!first)
        os << ",";
      os << val;
      first = false;
    }
    os << "\n";
  }

  StringRef getArgument() const final { return "test-print-alias"; }
  StringRef getDescription() const final {
    return "print the result of the alias analysis pass";
  }

  void runOnOperation() override {
    Operation *operation = getOperation();
    auto &os = llvm::errs();

    // Header line: the symbol name of the operation being analyzed.
    std::string funcName =
        SymbolTable::getSymbolName(operation).getValue().str();
    os << funcName << "\n";

    SharedMemoryAliasAnalysis analysis(&getContext());
    analysis.run(operation);

    // AsmState over the enclosing module; getValueOperandName uses it to
    // produce the printed names of values.
    AsmState state(operation->getParentOfType<ModuleOp>());

    // Returns the sorted names of the allocations recorded for `value`, or an
    // empty list when the analysis has no lattice element for it.
    auto getAllocOpNames = [&](Value value) {
      SmallVector<std::string, 4> opNames;
      auto *latticeElement = analysis.lookupLatticeElement(value);
      if (!latticeElement)
        return opNames;
      for (const auto &alias : latticeElement->getValue().getAllocs()) {
        // Name each allocation by the first result of its defining op. If
        // the alias has no defining op (e.g. a block argument), name the
        // value itself rather than dereferencing a null pointer.
        Value named = alias;
        if (Operation *defOp = alias.getDefiningOp())
          named = defOp->getResult(0);
        opNames.push_back(getValueOperandName(named, state));
      }
      // Ensure deterministic output regardless of getAllocs() iteration order.
      std::sort(opNames.begin(), opNames.end());
      return opNames;
    };

    // Visit every nested operation using Operation::walk's default order.
    operation->walk([&](Operation *op) {
      // Ops without results have nothing to report.
      if (op->getNumResults() < 1)
        return;
      if (auto forOp = dyn_cast<scf::ForOp>(op)) {
        // Report each region iter arg under its own name, with the aliases
        // of the loop operand that initializes it.
        for (BlockArgument arg : forOp.getRegionIterArgs()) {
          Value init = forOp.getOpOperandForRegionIterArg(arg).get();
          auto opNames = getAllocOpNames(init);
          auto argName = getValueOperandName(arg, state);
          print(argName, opNames, os);
        }
      }
      for (Value result : op->getResults()) {
        auto opNames = getAllocOpNames(result);
        auto resultName = getValueOperandName(result, state);
        print(resultName, opNames, os);
      }
    });
  }
};

} // namespace

namespace mlir {
namespace test {
/// Registers TestAliasPass (`test-print-alias`) with the global pass registry.
void registerTestAliasPass() { PassRegistration<TestAliasPass>(); }
} // namespace test
} // namespace mlir