Files
triton/test/lib/Analysis/TestAllocation.cpp

55 lines
1.8 KiB
C++
Raw Permalink Normal View History

#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
using namespace mlir;
namespace {
/// Test-only pass, registered as `test-print-allocation`, that runs the
/// `Allocation` analysis on one function and dumps its result to llvm::errs().
struct TestAllocationPass
    : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
  // LLVM15+
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

  /// Command-line argument that selects this pass.
  StringRef getArgument() const final { return "test-print-allocation"; }

  /// One-line description shown in the pass help text.
  StringRef getDescription() const final {
    return "print the result of the allocation pass";
  }

  /// Prints, in this order:
  ///   1. the function's symbol name;
  ///   2. for each visited op that has a valid op-keyed (scratch) buffer:
  ///        "scratch offset = <offset>, size = <size>";
  ///   3. for each result of that op that has a valid buffer:
  ///        "offset = <offset>, size = <size>";
  ///   4. "size = <allocation.getSharedMemorySize()>".
  /// Items 2 and 3 repeat per op, in `Operation::walk` visit order.
  void runOnOperation() override {
    Operation *funcOp = getOperation();
    auto &out = llvm::errs();

    // Converting to std::string drops the quotes that printing the name
    // attribute directly would add.
    auto funcName = SymbolTable::getSymbolName(funcOp).getValue().str();
    out << funcName << "\n";

    Allocation allocation(funcOp);

    // Emits "<label><offset>, size = <size>" when `bufferId` is valid; an
    // invalid id produces no output. Declared generic so it accepts the id
    // returned by both getBufferId(Operation *) and getBufferId(Value).
    auto printBuffer = [&](const char *label, auto bufferId) {
      if (bufferId != Allocation::InvalidBufferId) {
        size_t bufferOffset = allocation.getOffset(bufferId);
        size_t bufferSize = allocation.getAllocatedSize(bufferId);
        out << label << bufferOffset << ", size = " << bufferSize << "\n";
      }
    };

    funcOp->walk([&](Operation *op) {
      // Buffer keyed by the operation itself (reported as scratch space).
      printBuffer("scratch offset = ", allocation.getBufferId(op));
      // Buffers keyed by individual results; result-less ops print nothing.
      for (Value result : op->getResults())
        printBuffer("offset = ", allocation.getBufferId(result));
    });

    out << "size = " << allocation.getSharedMemorySize() << "\n";
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestAllocationPass() { PassRegistration<TestAllocationPass>(); }
} // namespace test
} // namespace mlir