[BACKEND] Add C++ tests for PTXFormat and some tiny refinement (#109)

This PR does

1. Add some C++ tests for `PTXFormat`
2. Enhance the functionality of `PTXFormat`, make a `PTXInstr` instance
can be called multiple times similar as a C function.
This commit is contained in:
Yan Chunwei
2022-09-10 00:15:07 +08:00
committed by GitHub
parent a9464f4993
commit 2a852044d9
10 changed files with 231 additions and 92 deletions

View File

@@ -49,7 +49,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
return newConstantOperand(ss.str());
}
std::string PTXBuilder::getConstrains() const {
std::string PTXBuilder::getConstraints() const {
auto args = getAllArgs();
llvm::SmallVector<std::string, 4> argReprs;
for (auto arg : args)
@@ -78,7 +78,7 @@ std::string PTXInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return llvm::formatv("${0}", idx);
return "$" + std::to_string(idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
@@ -90,7 +90,9 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
return llvm::formatv("[ ${0} + {1} ]", idx, off);
std::stringstream ss;
ss << "[ $" << idx << " + " << off << " ]";
return ss.str();
};
return opr;
@@ -98,14 +100,24 @@ PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
std::string PTXBuilder::dump() const {
llvm::SmallVector<std::string> lines;
for (auto &instr : instrs) {
lines.push_back(instr->dump());
for (auto &exec : executions) {
lines.push_back(exec->dump());
}
return strJoin(lines, "\n\t");
return strJoin(lines, "\r\n");
}
std::string PTXInstrCommon::dump() const {
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs));
return *builder->executions.back();
}
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
return call(oprs);
}
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
if (pred)
@@ -114,7 +126,7 @@ std::string PTXInstrCommon::dump() const {
else
os << pred->repr(pred->idx);
std::string instrRepr = strJoin(instrParts, ".");
std::string instrRepr = strJoin(instr->instrParts, ".");
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
@@ -128,7 +140,8 @@ std::string PTXInstrCommon::dump() const {
return osStr;
}
SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
SmallVector<PTXInstrExecution::Operand *>
PTXInstrExecution::getArgList() const {
SmallVector<Operand *> args;
for (auto *arg : argsInOrder) {
if (arg->isList())
@@ -138,11 +151,5 @@ SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
}
return args;
}
void PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
for (auto *opr : oprs) {
addOperand(opr);
}
}
} // namespace triton
} // namespace mlir