[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)
This PR does 1. Add `onlyBindMLIRArgs` argument to `PTXInstrCommon::call` method to support passing in a whole PTX code snippet 2. Refine the APIs and simplify the code usage.
This commit is contained in:
@@ -128,19 +128,26 @@ std::string PTXBuilder::dump() const {
|
||||
return strJoin(lines, "\n\t");
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
builder->executions.emplace_back(
|
||||
std::make_unique<PTXInstrExecution>(this, oprs));
|
||||
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
|
||||
return *builder->executions.back();
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||
return call(oprs);
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
return call(oprs, onlyAttachMLIRArgs);
|
||||
}
|
||||
|
||||
std::string PTXInstrExecution::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
if (onlyAttachMLIRArgs)
|
||||
return instrRepr;
|
||||
|
||||
if (pred) {
|
||||
if (!pred->repr)
|
||||
os << "@" << pred->dump() << " ";
|
||||
@@ -148,8 +155,6 @@ std::string PTXInstrExecution::dump() const {
|
||||
os << pred->repr(pred->idx) << " ";
|
||||
}
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
argReprs.push_back(arg->dump());
|
||||
@@ -174,5 +179,27 @@ PTXInstrExecution::getArgList() const {
|
||||
return args;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::global() {
|
||||
o("global");
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::shared() {
|
||||
o("shared");
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
|
||||
if (vecWidth > 1) {
|
||||
o("v" + std::to_string(vecWidth), predicate);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
PTXInstr &PTXInstr::b(int width) {
|
||||
o("b" + std::to_string(width));
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
Reference in New Issue
Block a user