[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)

This PR does

1. Add `onlyBindMLIRArgs` argument to `PTXInstrCommon::call` method to
support passing in a whole PTX code snippet
2. Refine the APIs and simplify the code usage.
This commit is contained in:
Yan Chunwei
2022-11-10 13:41:52 +08:00
committed by GitHub
parent 4640023d9b
commit 8832e32683
4 changed files with 105 additions and 91 deletions

View File

@@ -128,19 +128,26 @@ std::string PTXBuilder::dump() const {
return strJoin(lines, "\n\t");
}
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs));
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
return *builder->executions.back();
}
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
return call(oprs);
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
return call(oprs, onlyAttachMLIRArgs);
}
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
std::string instrRepr = strJoin(instr->instrParts, ".");
if (onlyAttachMLIRArgs)
return instrRepr;
if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
@@ -148,8 +155,6 @@ std::string PTXInstrExecution::dump() const {
os << pred->repr(pred->idx) << " ";
}
std::string instrRepr = strJoin(instr->instrParts, ".");
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
argReprs.push_back(arg->dump());
@@ -174,5 +179,27 @@ PTXInstrExecution::getArgList() const {
return args;
}
PTXInstr &PTXInstr::global() {
o("global");
return *this;
}
PTXInstr &PTXInstr::shared() {
o("shared");
return *this;
}
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
PTXInstr &PTXInstr::b(int width) {
o("b" + std::to_string(width));
return *this;
}
} // namespace triton
} // namespace mlir