Analyze shared memory alias (#81)

The purpose of this PR is analyzing shared memory aliases so that we can
fix memory allocation bugs and save memory allocations in triton code
involving complex control flows.

Changes to memory bar and allocation are on the way.

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Keren Zhou
2022-08-29 10:43:20 -07:00
committed by GitHub
parent 83287d7193
commit 02ebf24d35
15 changed files with 761 additions and 61 deletions

View File

@@ -1,4 +1,5 @@
add_mlir_library(TritonTestAnalysis
TestAlias.cpp
TestAxisInfo.cpp
TestAllocation.cpp
TestMembar.cpp

View File

@@ -0,0 +1,92 @@
#include "mlir/IR/AsmState.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Alias.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
namespace {
struct TestAliasPass
: public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {
// LLVM15+
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
static void print(StringRef name, SmallVector<std::string, 4> &vals,
raw_ostream &os) {
if (vals.empty())
return;
os << name << " -> ";
size_t i = 0;
for (auto val : vals) {
if (i != 0)
os << ",";
os << val;
++i;
}
os << "\n";
}
StringRef getArgument() const final { return "test-print-alias"; }
StringRef getDescription() const final {
return "print the result of the alias analysis pass";
}
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
os << op_name << "\n";
SharedMemoryAliasAnalysis analysis(&getContext());
analysis.run(operation);
AsmState state(operation->getParentOfType<ModuleOp>());
// Get operation ids of value's aliases
auto getAllocOpNames = [&](Value value) {
LatticeElement<AliasInfo> *latticeElement =
analysis.lookupLatticeElement(value);
SmallVector<std::string, 4> opNames;
if (latticeElement) {
auto &info = latticeElement->getValue();
if (!info.getAllocs().empty()) {
for (auto &alias : info.getAllocs()) {
auto opName =
getValueOperandName(alias.getDefiningOp()->getResult(0), state);
opNames.push_back(std::move(opName));
}
}
}
// Ensure deterministic output
std::sort(opNames.begin(), opNames.end());
return opNames;
};
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() < 1)
return;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
auto opNames = getAllocOpNames(operand);
auto argName = getValueOperandName(arg.value(), state);
print(argName, opNames, os);
}
}
for (auto result : llvm::enumerate(op->getResults())) {
auto opNames = getAllocOpNames(result.value());
auto resultName = getValueOperandName(result.value(), state);
print(resultName, opNames, os);
}
});
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestAliasPass() { PassRegistration<TestAliasPass>(); }
} // namespace test
} // namespace mlir