[BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77)
This PR does following things: Enhance the PTXFormat by Introducing PTXBuilder to enable multiple instructions in a single asm program override PTXInstr's operator() method to enable instr(opr0, opr1) style of setting operands for an instruction Refactor the PTX code used in LoadOpConversion with PTXFormat Authored-by: goostavz <gzhu@nvidia.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <sstream> // unify to llvm::raw_string_ostream ?
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// TODO(Superjomn) Move to a global utility file?
|
||||
std::string strJoin(llvm::ArrayRef<std::string> strs,
|
||||
llvm::StringRef delimiter) {
|
||||
std::string osStr;
|
||||
@@ -16,11 +18,101 @@ std::string strJoin(llvm::ArrayRef<std::string> strs,
|
||||
return osStr;
|
||||
}
|
||||
|
||||
std::string PtxInstr::dump() const {
|
||||
PTXInstr::Operand *
|
||||
PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int)> formater) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
||||
auto *opr = argArchive.back().get();
|
||||
opr->repr = formater;
|
||||
opr->idx = oprCounter++;
|
||||
return opr;
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) {
|
||||
// Constraint should be something like "=r"
|
||||
assert(!constraint.empty() && constraint[0] == '=');
|
||||
auto *opr = newOperand();
|
||||
opr->idx = oprCounter++;
|
||||
opr->constraint = constraint;
|
||||
return opr;
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
argArchive.back()->repr = [v](int idx) { return v; };
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
|
||||
std::stringstream ss;
|
||||
ss << "0x" << std::hex << v;
|
||||
return newConstantOperand(ss.str());
|
||||
}
|
||||
|
||||
std::string PTXBuilder::getConstrains() const {
|
||||
auto args = getAllArgs();
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto arg : args)
|
||||
argReprs.push_back(arg->constraint);
|
||||
return strJoin(argReprs, ",");
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
|
||||
llvm::SmallVector<Value, 4> res;
|
||||
for (auto &arg : argArchive) {
|
||||
if (!arg->isList() && arg->value)
|
||||
res.push_back(arg->value);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<PTXBuilder::Operand *> PTXBuilder::getAllArgs() const {
|
||||
llvm::SmallVector<Operand *, 4> res;
|
||||
for (auto &x : argArchive)
|
||||
if (!x->isList())
|
||||
res.push_back(x.get());
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string PTXInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return llvm::formatv("${0}", idx);
|
||||
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
oprs.push_back(opr->dump());
|
||||
return "{ " + strJoin(oprs, ", ") + " }";
|
||||
}
|
||||
|
||||
PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
||||
StringRef constraint, int off) {
|
||||
auto *opr = newOperand(addr, constraint);
|
||||
opr->repr = [off](int idx) -> std::string {
|
||||
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
||||
};
|
||||
|
||||
return opr;
|
||||
}
|
||||
|
||||
std::string PTXBuilder::dump() const {
|
||||
llvm::SmallVector<std::string> lines;
|
||||
for (auto &instr : instrs) {
|
||||
lines.push_back(instr->dump());
|
||||
}
|
||||
|
||||
return strJoin(lines, "\n\t");
|
||||
}
|
||||
|
||||
std::string PTXInstrCommon::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
if (pred)
|
||||
os << "@" << pred->dump() << " ";
|
||||
if (!pred->repr)
|
||||
os << "@" << pred->dump() << " ";
|
||||
else
|
||||
os << pred->repr(pred->idx);
|
||||
|
||||
std::string instrRepr = strJoin(instrParts, ".");
|
||||
|
||||
@@ -36,7 +128,7 @@ std::string PtxInstr::dump() const {
|
||||
return osStr;
|
||||
}
|
||||
|
||||
llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
|
||||
SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
||||
SmallVector<Operand *> args;
|
||||
for (auto *arg : argsInOrder) {
|
||||
if (arg->isList())
|
||||
@@ -47,35 +139,10 @@ llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
|
||||
return args;
|
||||
}
|
||||
|
||||
PtxInstr::Operand *
|
||||
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int)> formater) {
|
||||
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
||||
auto *opr = argArchive.back().get();
|
||||
opr->repr = formater;
|
||||
opr->idx = oprCounter++;
|
||||
return opr;
|
||||
}
|
||||
|
||||
std::string PtxInstr::Operand::dump() const {
|
||||
if (repr)
|
||||
return repr(idx);
|
||||
if (!isList())
|
||||
return llvm::formatv("${0}", idx);
|
||||
llvm::SmallVector<std::string> oprs;
|
||||
for (auto *opr : list)
|
||||
oprs.push_back(opr->dump());
|
||||
return "{ " + strJoin(oprs, ", ") + " }";
|
||||
}
|
||||
|
||||
PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
|
||||
StringRef constraint, int off) {
|
||||
auto *opr = newOperand(addr, constraint);
|
||||
opr->repr = [off](int idx) -> std::string {
|
||||
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
||||
};
|
||||
|
||||
return opr;
|
||||
void PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||
for (auto *opr : oprs) {
|
||||
addOperand(opr);
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
Reference in New Issue
Block a user