55 lines
1.8 KiB
C++
55 lines
1.8 KiB
C++
![]() |
#include "mlir/Pass/Pass.h"
|
||
|
#include "triton/Analysis/Allocation.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
|
||
|
namespace {
|
||
|
|
||
|
struct TestAllocationPass
|
||
|
: public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
|
||
|
|
||
|
// LLVM15+
|
||
|
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
|
||
|
|
||
|
StringRef getArgument() const final { return "test-print-allocation"; }
|
||
|
StringRef getDescription() const final {
|
||
|
return "print the result of the allocation pass";
|
||
|
}
|
||
|
|
||
|
void runOnOperation() override {
|
||
|
Operation *operation = getOperation();
|
||
|
auto &os = llvm::errs();
|
||
|
// Convert to std::string can remove quotes from op_name
|
||
|
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
|
||
|
os << op_name << "\n";
|
||
|
Allocation allocation(operation);
|
||
|
operation->walk([&](Operation *op) {
|
||
|
auto scratchBufferId = allocation.getBufferId(op);
|
||
|
if (scratchBufferId != Allocation::InvalidBufferId) {
|
||
|
size_t offset = allocation.getOffset(scratchBufferId);
|
||
|
size_t size = allocation.getAllocatedSize(scratchBufferId);
|
||
|
os << "scratch offset = " << offset << ", size = " << size << "\n";
|
||
|
}
|
||
|
if (op->getNumResults() < 1)
|
||
|
return;
|
||
|
for (Value result : op->getResults()) {
|
||
|
auto bufferId = allocation.getBufferId(result);
|
||
|
if (bufferId != Allocation::InvalidBufferId) {
|
||
|
size_t offset = allocation.getOffset(bufferId);
|
||
|
size_t size = allocation.getAllocatedSize(bufferId);
|
||
|
os << "offset = " << offset << ", size = " << size << "\n";
|
||
|
}
|
||
|
}
|
||
|
});
|
||
|
os << "size = " << allocation.getSharedMemorySize() << "\n";
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace test {
|
||
|
void registerTestAllocationPass() { PassRegistration<TestAllocationPass>(); }
|
||
|
} // namespace test
|
||
|
} // namespace mlir
|