[BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77)
This PR does the following things: (1) enhances PTXFormat by introducing PTXBuilder to enable multiple instructions in a single asm program; (2) overrides PTXInstr's operator() method to enable the instr(opr0, opr1) style of setting operands for an instruction; (3) refactors the PTX code used in LoadOpConversion with PTXFormat. Authored-by: goostavz <gzhu@nvidia.com>
This commit is contained in:
@@ -13,29 +13,52 @@ namespace mlir {
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
// TODO(Superjomn) Move to a global utility file?
|
||||
std::string strJoin(llvm::ArrayRef<std::string> strs,
|
||||
llvm::StringRef delimiter);
|
||||
class PTXInstr;
|
||||
class PTXInstrCommon;
|
||||
|
||||
// A helper for building a single inline ASM instruction, the objective of
|
||||
// PtxInstr is to give a thin encapsulation and make the ASM code for MLIR LLVM
|
||||
// Dialect more clear. Currently, several factors are introduced to reduce the
|
||||
// need for mixing string and C++ if-else code.
|
||||
// PTXBuilder helps to manage a PTX asm program consisting of one or multiple
|
||||
// instructions.
|
||||
//
|
||||
// A helper for building an ASM program, the objective of PTXBuilder is to give a
|
||||
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||
// Currently, several factors are introduced to reduce the need for mixing
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To build: asm("add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k));
|
||||
// To build: asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p));
|
||||
//
|
||||
// PtxInstr mulr("mul");
|
||||
// mulr.o("lo").o("u32").addOperand(valueI, "=r") // %0 bind to valueI
|
||||
// .addOperand(valueJ, "r") // %1 bind to valueJ
|
||||
// .addOperand(valueK, "k"); // %2 bind to valueK
|
||||
// PTXBuilder builder;
|
||||
// auto& add = builder.create<>();
|
||||
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
|
||||
// // predicate here binds %0 to pVal, pVal is a mlir::Value
|
||||
//
|
||||
// mulr.getConstrains() // get "=r,r,k"
|
||||
// mulr.getAllMlirArgs() // get {valueI, valueJ, valueK}
|
||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||
// add(iOpr, jOpr, kOpr); // set operands
|
||||
//
|
||||
// TODO(Superjomn) Add multi-line ASM code support and register support later.
|
||||
struct PtxInstr {
|
||||
explicit PtxInstr(const std::string &name) { o(name); }
|
||||
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
//
|
||||
// To get all the mlir::Value used in the PTX code,
|
||||
//
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the constraints separated by ",",
|
||||
// builder.getConstrains() // get "=r,r,k"
|
||||
//
|
||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& instr0 = builder.create<>();
|
||||
// auto& instr1 = builder.create<>();
|
||||
// auto& instr2 = builder.create<>();
|
||||
//
|
||||
// NOTE, the instructions will be serialized in the order of creation.
|
||||
//
|
||||
// There are several derived instruction types for typical instructions, for
|
||||
// example, the PtxIOInstr for ld and st instructions.
|
||||
struct PTXBuilder {
|
||||
struct Operand {
|
||||
std::string constraint;
|
||||
Value value;
|
||||
@@ -48,16 +71,29 @@ struct PtxInstr {
|
||||
Operand(Value value, StringRef constraint)
|
||||
: value(value), constraint(constraint) {}
|
||||
|
||||
bool isList() const { return !value; }
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
Operand *listAppend(Operand *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Operand *listGet(size_t nth) const {
|
||||
assert(nth < list.size());
|
||||
return list[nth];
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
|
||||
instrs.emplace_back(std::make_unique<INSTR>(this, name));
|
||||
return static_cast<INSTR *>(instrs.back().get());
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
// Create a new operand. It will not add to operand list.
|
||||
// @value: the MLIR value bound to this operand.
|
||||
// @constraint: ASM operand constraint, e.g. "=r"
|
||||
@@ -66,7 +102,65 @@ struct PtxInstr {
|
||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int idx)> formater = nullptr);
|
||||
|
||||
// Append the operand to the intruction's operand list.
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
Operand *newOperand(StringRef constraint);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
Operand *newConstantOperand(const std::string &v);
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||
|
||||
llvm::SmallVector<Operand *> getAllArgs() const;
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstrains() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
friend class PTXInstr;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
// PTX instruction common interface.
|
||||
// Put the generic logic for all the instructions here.
|
||||
struct PTXInstrCommon {
|
||||
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
|
||||
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> getArgList() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
// clang-format off
|
||||
void operator()(Operand* a) { operator()({a}); }
|
||||
void operator()(Operand* a, Operand* b) { operator()({a, b}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c) { operator()({a, b, c}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d) { operator()({a, b, c, d}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { operator()({a, b, c, d, e}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { operator()({a, b, c, d, e, f}); }
|
||||
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { operator()({a, b, c, d, e, f, g}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
void operator()(llvm::ArrayRef<Operand *> oprs);
|
||||
|
||||
protected:
|
||||
// Append the operand to the instruction's operand list.
|
||||
Operand *addOperand(Operand *opr) {
|
||||
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
|
||||
argsInOrder.end());
|
||||
@@ -74,78 +168,47 @@ struct PtxInstr {
|
||||
return opr;
|
||||
}
|
||||
|
||||
// Create and add an operand to the intruction's operand list.
|
||||
Operand *addOperand(mlir::Value value, StringRef constraint) {
|
||||
auto *opr = newOperand(value, constraint);
|
||||
return addOperand(opr);
|
||||
}
|
||||
PTXBuilder *builder{};
|
||||
Operand *pred{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
};
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PtxInstr &predicate(mlir::Value value, StringRef constraint) {
|
||||
pred = newOperand(value, constraint);
|
||||
return *this;
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
|
||||
: PTXInstrCommon(builder) {
|
||||
o(name);
|
||||
}
|
||||
|
||||
// Append a suffix to the instruction.
|
||||
// e.g. PtxInstr("add").o("s32") get a add.s32.
|
||||
// e.g. PTXInstr("add").o("s32") get a add.s32.
|
||||
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
||||
// code needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||
// get a `add.s32` if isS32 is true.
|
||||
PtxInstr &o(const std::string &suffix, bool predicate = true) {
|
||||
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||
if (predicate)
|
||||
instrParts.push_back(suffix);
|
||||
return *this;
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
|
||||
PtxInstr &addListOperation(llvm::ArrayRef<Operand *> list) {
|
||||
auto *opr = newList();
|
||||
for (auto *v : list)
|
||||
opr->listAppend(v);
|
||||
addOperand(opr);
|
||||
return *this;
|
||||
// Prefix a predicate to the instruction.
|
||||
ConcreteT &predicate(mlir::Value value, StringRef constraint) {
|
||||
pred = builder->newOperand(value, constraint);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newList() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
// Prefix a !predicate to the instruction.
|
||||
ConcreteT &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return llvm::formatv("@!%{0}", idx); };
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
llvm::SmallVector<Operand *, 4> getArgList() const;
|
||||
llvm::SmallVector<Operand *, 4> getAllArgs() const {
|
||||
llvm::SmallVector<Operand *, 4> res;
|
||||
for (auto &x : argArchive)
|
||||
if (!x->isList())
|
||||
res.push_back(x.get());
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string getConstrains() const {
|
||||
auto args = getAllArgs();
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto arg : args)
|
||||
argReprs.push_back(arg->constraint);
|
||||
return strJoin(argReprs, ",");
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMlirArgs() const {
|
||||
llvm::SmallVector<Value, 4> res;
|
||||
for (auto &arg : argArchive) {
|
||||
if (!arg->isList())
|
||||
res.push_back(arg->value);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
protected:
|
||||
Operand *pred{};
|
||||
int oprCounter{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
std::string argStr;
|
||||
// The plain (non-specialized) PTX instruction type. It adds no members of its
// own: it derives from PTXInstrBase<PTXInstr> and inherits that base's
// constructors. It is also the default INSTR template argument of
// PTXBuilder::create<>().
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||
};
|
||||
|
||||
// A helper for PTX ld/st instruction.
|
||||
@@ -153,8 +216,8 @@ protected:
|
||||
// PtxIOInstr store("st");
|
||||
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
||||
// store.addAddr(addrValue, "l", off);
|
||||
struct PtxIOInstr : public PtxInstr {
|
||||
PtxIOInstr(const std::string &name) : PtxInstr(name) {}
|
||||
struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
||||
using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
|
||||
|
||||
// Add ".global" suffix to instruction
|
||||
PtxIOInstr &global(bool predicate = true) {
|
||||
@@ -175,14 +238,6 @@ struct PtxIOInstr : public PtxInstr {
|
||||
o(llvm::formatv("b{0}", width));
|
||||
return *this;
|
||||
}
|
||||
|
||||
PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) {
|
||||
auto *operand = newAddrOperand(addr, constraint, off);
|
||||
addOperand(operand);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
|
Reference in New Issue
Block a user