Analyze shared memory alias (#81)
The purpose of this PR is analyzing shared memory aliases so that we can fix memory allocation bugs and save memory allocations in triton code involving complex control flows. Changes to memory bar and allocation are on the way. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
add_mlir_library(TritonTestAnalysis
|
||||
TestAlias.cpp
|
||||
TestAxisInfo.cpp
|
||||
TestAllocation.cpp
|
||||
TestMembar.cpp
|
||||
|
92
test/lib/Analysis/TestAlias.cpp
Normal file
92
test/lib/Analysis/TestAlias.cpp
Normal file
@@ -0,0 +1,92 @@
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestAliasPass
|
||||
: public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
||||
static void print(StringRef name, SmallVector<std::string, 4> &vals,
|
||||
raw_ostream &os) {
|
||||
if (vals.empty())
|
||||
return;
|
||||
os << name << " -> ";
|
||||
size_t i = 0;
|
||||
for (auto val : vals) {
|
||||
if (i != 0)
|
||||
os << ",";
|
||||
os << val;
|
||||
++i;
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alias"; }
|
||||
StringRef getDescription() const final {
|
||||
return "print the result of the alias analysis pass";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << op_name << "\n";
|
||||
|
||||
SharedMemoryAliasAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
|
||||
AsmState state(operation->getParentOfType<ModuleOp>());
|
||||
// Get operation ids of value's aliases
|
||||
auto getAllocOpNames = [&](Value value) {
|
||||
LatticeElement<AliasInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(value);
|
||||
SmallVector<std::string, 4> opNames;
|
||||
if (latticeElement) {
|
||||
auto &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto &alias : info.getAllocs()) {
|
||||
auto opName =
|
||||
getValueOperandName(alias.getDefiningOp()->getResult(0), state);
|
||||
opNames.push_back(std::move(opName));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ensure deterministic output
|
||||
std::sort(opNames.begin(), opNames.end());
|
||||
return opNames;
|
||||
};
|
||||
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||
for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
|
||||
auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
|
||||
auto opNames = getAllocOpNames(operand);
|
||||
auto argName = getValueOperandName(arg.value(), state);
|
||||
print(argName, opNames, os);
|
||||
}
|
||||
}
|
||||
for (auto result : llvm::enumerate(op->getResults())) {
|
||||
auto opNames = getAllocOpNames(result.value());
|
||||
auto resultName = getValueOperandName(result.value(), state);
|
||||
print(resultName, opNames, os);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAliasPass() { PassRegistration<TestAliasPass>(); }
|
||||
} // namespace test
|
||||
} // namespace mlir
|
Reference in New Issue
Block a user