[BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77)
This PR does following things: Enhance the PTXFormat by Introducing PTXBuilder to enable multiple instructions in a single asm program override PTXInstr's operator() method to enable instr(opr0, opr1) style of setting operands for an instruction Refactor the PTX code used in LoadOpConversion with PTXFormat Authored-by: goostavz <gzhu@nvidia.com>
This commit is contained in:
@@ -13,29 +13,52 @@ namespace mlir {
|
|||||||
namespace triton {
|
namespace triton {
|
||||||
using llvm::StringRef;
|
using llvm::StringRef;
|
||||||
|
|
||||||
// TODO(Superjomn) Move to a global utility file?
|
class PTXInstr;
|
||||||
std::string strJoin(llvm::ArrayRef<std::string> strs,
|
class PTXInstrCommon;
|
||||||
llvm::StringRef delimiter);
|
|
||||||
|
|
||||||
// A helper for building a single inline ASM instruction, the objective of
|
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||||
// PtxInstr is to give a thin encapsulation and make the ASM code for MLIR LLVM
|
// instructions.
|
||||||
// Dialect more clear. Currently, several factors are introduced to reduce the
|
//
|
||||||
// need for mixing string and C++ if-else code.
|
// A helper for building a ASM program, the objective of PTXBuilder is to give a
|
||||||
|
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||||
|
// Currently, several factors are introduced to reduce the need for mixing
|
||||||
|
// string and C++ if-else code.
|
||||||
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
// To build: asm("add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k));
|
// To build: asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p));
|
||||||
//
|
//
|
||||||
// PtxInstr mulr("mul");
|
// PTXBuilder builder;
|
||||||
// mulr.o("lo").o("u32").addOperand(valueI, "=r") // %0 bind to valueI
|
// auto& add = builder.create<>();
|
||||||
// .addOperand(valueJ, "r") // %1 bind to valueJ
|
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
|
||||||
// .addOperand(valueK, "k"); // %2 bind to valueK
|
// // predicate here binds %0 to pVal, pVal is a mlir::Value
|
||||||
//
|
//
|
||||||
// mulr.getConstrains() // get "=r,r,k"
|
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||||
// mulr.getAllMlirArgs() // get {valueI, valueJ, valueK}
|
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||||
|
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||||
|
// add(iOpr, jOpr, kOpr); // set operands
|
||||||
//
|
//
|
||||||
// TODO(Superjomn) Add multi-line ASM code support and register support later.
|
// To get the asm code:
|
||||||
struct PtxInstr {
|
// builder.dump()
|
||||||
explicit PtxInstr(const std::string &name) { o(name); }
|
//
|
||||||
|
// To get all the mlir::Value used in the PTX code,
|
||||||
|
//
|
||||||
|
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||||
|
//
|
||||||
|
// To get the string containing all the contraints with "," seperated,
|
||||||
|
// builder.getConstrains() // get "=r,r,k"
|
||||||
|
//
|
||||||
|
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||||
|
//
|
||||||
|
// PTXBuilder builder;
|
||||||
|
// auto& instr0 = builder.create<>();
|
||||||
|
// auto& instr1 = builder.create<>();
|
||||||
|
// auto& instr2 = builder.create<>();
|
||||||
|
//
|
||||||
|
// NOTE, the instructions will be serialized in the order of creation.
|
||||||
|
//
|
||||||
|
// There are several derived instruction type for typical instructions, for
|
||||||
|
// example, the PtxIOInstr for ld and st instructions.
|
||||||
|
struct PTXBuilder {
|
||||||
struct Operand {
|
struct Operand {
|
||||||
std::string constraint;
|
std::string constraint;
|
||||||
Value value;
|
Value value;
|
||||||
@@ -48,16 +71,29 @@ struct PtxInstr {
|
|||||||
Operand(Value value, StringRef constraint)
|
Operand(Value value, StringRef constraint)
|
||||||
: value(value), constraint(constraint) {}
|
: value(value), constraint(constraint) {}
|
||||||
|
|
||||||
bool isList() const { return !value; }
|
bool isList() const { return !value && constraint.empty(); }
|
||||||
|
|
||||||
Operand *listAppend(Operand *arg) {
|
Operand *listAppend(Operand *arg) {
|
||||||
list.push_back(arg);
|
list.push_back(arg);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Operand *listGet(size_t nth) const {
|
||||||
|
assert(nth < list.size());
|
||||||
|
return list[nth];
|
||||||
|
}
|
||||||
|
|
||||||
std::string dump() const;
|
std::string dump() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
|
||||||
|
instrs.emplace_back(std::make_unique<INSTR>(this, name));
|
||||||
|
return static_cast<INSTR *>(instrs.back().get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a list of operands.
|
||||||
|
Operand *newListOperand() { return newOperand(); }
|
||||||
|
|
||||||
// Create a new operand. It will not add to operand list.
|
// Create a new operand. It will not add to operand list.
|
||||||
// @value: the MLIR value bind to this operand.
|
// @value: the MLIR value bind to this operand.
|
||||||
// @constraint: ASM operand constraint, .e.g. "=r"
|
// @constraint: ASM operand constraint, .e.g. "=r"
|
||||||
@@ -66,7 +102,65 @@ struct PtxInstr {
|
|||||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||||
std::function<std::string(int idx)> formater = nullptr);
|
std::function<std::string(int idx)> formater = nullptr);
|
||||||
|
|
||||||
// Append the operand to the intruction's operand list.
|
// Create a new operand which is written to, that is, the constraint starts
|
||||||
|
// with "=", e.g. "=r".
|
||||||
|
Operand *newOperand(StringRef constraint);
|
||||||
|
|
||||||
|
// Create a constant integer operand.
|
||||||
|
Operand *newConstantOperand(int v);
|
||||||
|
// Create a constant operand with explicit code specified.
|
||||||
|
Operand *newConstantOperand(const std::string &v);
|
||||||
|
|
||||||
|
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||||
|
|
||||||
|
llvm::SmallVector<Operand *> getAllArgs() const;
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||||
|
|
||||||
|
std::string getConstrains() const;
|
||||||
|
|
||||||
|
std::string dump() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Operand *newOperand() {
|
||||||
|
argArchive.emplace_back(std::make_unique<Operand>());
|
||||||
|
return argArchive.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
friend class PTXInstr;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||||
|
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||||
|
int oprCounter{};
|
||||||
|
};
|
||||||
|
|
||||||
|
// PTX instruction common interface.
|
||||||
|
// Put the generic logic for all the instructions here.
|
||||||
|
struct PTXInstrCommon {
|
||||||
|
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
|
||||||
|
|
||||||
|
using Operand = PTXBuilder::Operand;
|
||||||
|
|
||||||
|
llvm::SmallVector<Operand *> getArgList() const;
|
||||||
|
|
||||||
|
std::string dump() const;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
void operator()(Operand* a) { operator()({a}); }
|
||||||
|
void operator()(Operand* a, Operand* b) { operator()({a, b}); }
|
||||||
|
void operator()(Operand* a, Operand* b, Operand* c) { operator()({a, b, c}); }
|
||||||
|
void operator()(Operand* a, Operand* b, Operand* c, Operand* d) { operator()({a, b, c, d}); }
|
||||||
|
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { operator()({a, b, c, d, e}); }
|
||||||
|
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { operator()({a, b, c, d, e, f}); }
|
||||||
|
void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { operator()({a, b, c, d, e, f, g}); }
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
// Set operands of this instruction.
|
||||||
|
void operator()(llvm::ArrayRef<Operand *> oprs);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Append the operand to the instruction's operand list.
|
||||||
Operand *addOperand(Operand *opr) {
|
Operand *addOperand(Operand *opr) {
|
||||||
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
|
assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
|
||||||
argsInOrder.end());
|
argsInOrder.end());
|
||||||
@@ -74,78 +168,47 @@ struct PtxInstr {
|
|||||||
return opr;
|
return opr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and add an operand to the intruction's operand list.
|
PTXBuilder *builder{};
|
||||||
Operand *addOperand(mlir::Value value, StringRef constraint) {
|
Operand *pred{};
|
||||||
auto *opr = newOperand(value, constraint);
|
llvm::SmallVector<std::string, 4> instrParts;
|
||||||
return addOperand(opr);
|
llvm::SmallVector<Operand *> argsInOrder;
|
||||||
}
|
};
|
||||||
|
|
||||||
// Prefix a predicate to the instruction.
|
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||||
PtxInstr &predicate(mlir::Value value, StringRef constraint) {
|
using Operand = PTXBuilder::Operand;
|
||||||
pred = newOperand(value, constraint);
|
|
||||||
return *this;
|
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
|
||||||
|
: PTXInstrCommon(builder) {
|
||||||
|
o(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append a suffix to the instruction.
|
// Append a suffix to the instruction.
|
||||||
// e.g. PtxInstr("add").o("s32") get a add.s32.
|
// e.g. PTXInstr("add").o("s32") get a add.s32.
|
||||||
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
||||||
// code needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||||
// get a `add.s32` if isS32 is true.
|
// get a `add.s32` if isS32 is true.
|
||||||
PtxInstr &o(const std::string &suffix, bool predicate = true) {
|
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||||
if (predicate)
|
if (predicate)
|
||||||
instrParts.push_back(suffix);
|
instrParts.push_back(suffix);
|
||||||
return *this;
|
return *static_cast<ConcreteT *>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
PtxInstr &addListOperation(llvm::ArrayRef<Operand *> list) {
|
// Prefix a predicate to the instruction.
|
||||||
auto *opr = newList();
|
ConcreteT &predicate(mlir::Value value, StringRef constraint) {
|
||||||
for (auto *v : list)
|
pred = builder->newOperand(value, constraint);
|
||||||
opr->listAppend(v);
|
return *static_cast<ConcreteT *>(this);
|
||||||
addOperand(opr);
|
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a list of operands.
|
// Prefix a !predicate to the instruction.
|
||||||
Operand *newList() {
|
ConcreteT &predicateNot(mlir::Value value, StringRef constraint) {
|
||||||
argArchive.emplace_back(std::make_unique<Operand>());
|
pred = builder->newOperand(value, constraint);
|
||||||
return argArchive.back().get();
|
pred->repr = [](int idx) { return llvm::formatv("@!%{0}", idx); };
|
||||||
|
return *static_cast<ConcreteT *>(this);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::string dump() const;
|
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||||
|
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||||
llvm::SmallVector<Operand *, 4> getArgList() const;
|
|
||||||
llvm::SmallVector<Operand *, 4> getAllArgs() const {
|
|
||||||
llvm::SmallVector<Operand *, 4> res;
|
|
||||||
for (auto &x : argArchive)
|
|
||||||
if (!x->isList())
|
|
||||||
res.push_back(x.get());
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getConstrains() const {
|
|
||||||
auto args = getAllArgs();
|
|
||||||
llvm::SmallVector<std::string, 4> argReprs;
|
|
||||||
for (auto arg : args)
|
|
||||||
argReprs.push_back(arg->constraint);
|
|
||||||
return strJoin(argReprs, ",");
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<Value, 4> getAllMlirArgs() const {
|
|
||||||
llvm::SmallVector<Value, 4> res;
|
|
||||||
for (auto &arg : argArchive) {
|
|
||||||
if (!arg->isList())
|
|
||||||
res.push_back(arg->value);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Operand *pred{};
|
|
||||||
int oprCounter{};
|
|
||||||
llvm::SmallVector<std::string, 4> instrParts;
|
|
||||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
|
||||||
llvm::SmallVector<Operand *> argsInOrder;
|
|
||||||
std::string argStr;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// A helper for PTX ld/st instruction.
|
// A helper for PTX ld/st instruction.
|
||||||
@@ -153,8 +216,8 @@ protected:
|
|||||||
// PtxIOInstr store("st");
|
// PtxIOInstr store("st");
|
||||||
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
||||||
// store.addAddr(addrValue, "l", off);
|
// store.addAddr(addrValue, "l", off);
|
||||||
struct PtxIOInstr : public PtxInstr {
|
struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
||||||
PtxIOInstr(const std::string &name) : PtxInstr(name) {}
|
using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
|
||||||
|
|
||||||
// Add ".global" suffix to instruction
|
// Add ".global" suffix to instruction
|
||||||
PtxIOInstr &global(bool predicate = true) {
|
PtxIOInstr &global(bool predicate = true) {
|
||||||
@@ -175,14 +238,6 @@ struct PtxIOInstr : public PtxInstr {
|
|||||||
o(llvm::formatv("b{0}", width));
|
o(llvm::formatv("b{0}", width));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) {
|
|
||||||
auto *operand = newAddrOperand(addr, constraint, off);
|
|
||||||
addOperand(operand);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include <sstream> // unify to llvm::raw_string_ostream ?
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
|
// TODO(Superjomn) Move to a global utility file?
|
||||||
std::string strJoin(llvm::ArrayRef<std::string> strs,
|
std::string strJoin(llvm::ArrayRef<std::string> strs,
|
||||||
llvm::StringRef delimiter) {
|
llvm::StringRef delimiter) {
|
||||||
std::string osStr;
|
std::string osStr;
|
||||||
@@ -16,11 +18,101 @@ std::string strJoin(llvm::ArrayRef<std::string> strs,
|
|||||||
return osStr;
|
return osStr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PtxInstr::dump() const {
|
PTXInstr::Operand *
|
||||||
|
PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
|
||||||
|
std::function<std::string(int)> formater) {
|
||||||
|
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
||||||
|
auto *opr = argArchive.back().get();
|
||||||
|
opr->repr = formater;
|
||||||
|
opr->idx = oprCounter++;
|
||||||
|
return opr;
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) {
|
||||||
|
// Constraint should be something like "=r"
|
||||||
|
assert(!constraint.empty() && constraint[0] == '=');
|
||||||
|
auto *opr = newOperand();
|
||||||
|
opr->idx = oprCounter++;
|
||||||
|
opr->constraint = constraint;
|
||||||
|
return opr;
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
|
||||||
|
argArchive.emplace_back(std::make_unique<Operand>());
|
||||||
|
argArchive.back()->repr = [v](int idx) { return v; };
|
||||||
|
return argArchive.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "0x" << std::hex << v;
|
||||||
|
return newConstantOperand(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PTXBuilder::getConstrains() const {
|
||||||
|
auto args = getAllArgs();
|
||||||
|
llvm::SmallVector<std::string, 4> argReprs;
|
||||||
|
for (auto arg : args)
|
||||||
|
argReprs.push_back(arg->constraint);
|
||||||
|
return strJoin(argReprs, ",");
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
|
||||||
|
llvm::SmallVector<Value, 4> res;
|
||||||
|
for (auto &arg : argArchive) {
|
||||||
|
if (!arg->isList() && arg->value)
|
||||||
|
res.push_back(arg->value);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<PTXBuilder::Operand *> PTXBuilder::getAllArgs() const {
|
||||||
|
llvm::SmallVector<Operand *, 4> res;
|
||||||
|
for (auto &x : argArchive)
|
||||||
|
if (!x->isList())
|
||||||
|
res.push_back(x.get());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PTXInstr::Operand::dump() const {
|
||||||
|
if (repr)
|
||||||
|
return repr(idx);
|
||||||
|
if (!isList())
|
||||||
|
return llvm::formatv("${0}", idx);
|
||||||
|
|
||||||
|
llvm::SmallVector<std::string> oprs;
|
||||||
|
for (auto *opr : list)
|
||||||
|
oprs.push_back(opr->dump());
|
||||||
|
return "{ " + strJoin(oprs, ", ") + " }";
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
||||||
|
StringRef constraint, int off) {
|
||||||
|
auto *opr = newOperand(addr, constraint);
|
||||||
|
opr->repr = [off](int idx) -> std::string {
|
||||||
|
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
||||||
|
};
|
||||||
|
|
||||||
|
return opr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PTXBuilder::dump() const {
|
||||||
|
llvm::SmallVector<std::string> lines;
|
||||||
|
for (auto &instr : instrs) {
|
||||||
|
lines.push_back(instr->dump());
|
||||||
|
}
|
||||||
|
|
||||||
|
return strJoin(lines, "\n\t");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PTXInstrCommon::dump() const {
|
||||||
std::string osStr;
|
std::string osStr;
|
||||||
llvm::raw_string_ostream os(osStr);
|
llvm::raw_string_ostream os(osStr);
|
||||||
if (pred)
|
if (pred)
|
||||||
os << "@" << pred->dump() << " ";
|
if (!pred->repr)
|
||||||
|
os << "@" << pred->dump() << " ";
|
||||||
|
else
|
||||||
|
os << pred->repr(pred->idx);
|
||||||
|
|
||||||
std::string instrRepr = strJoin(instrParts, ".");
|
std::string instrRepr = strJoin(instrParts, ".");
|
||||||
|
|
||||||
@@ -36,7 +128,7 @@ std::string PtxInstr::dump() const {
|
|||||||
return osStr;
|
return osStr;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
|
SmallVector<PTXInstrCommon::Operand *> PTXInstrCommon::getArgList() const {
|
||||||
SmallVector<Operand *> args;
|
SmallVector<Operand *> args;
|
||||||
for (auto *arg : argsInOrder) {
|
for (auto *arg : argsInOrder) {
|
||||||
if (arg->isList())
|
if (arg->isList())
|
||||||
@@ -47,35 +139,10 @@ llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
|
|||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
PtxInstr::Operand *
|
void PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||||
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
|
for (auto *opr : oprs) {
|
||||||
std::function<std::string(int)> formater) {
|
addOperand(opr);
|
||||||
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
}
|
||||||
auto *opr = argArchive.back().get();
|
|
||||||
opr->repr = formater;
|
|
||||||
opr->idx = oprCounter++;
|
|
||||||
return opr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string PtxInstr::Operand::dump() const {
|
|
||||||
if (repr)
|
|
||||||
return repr(idx);
|
|
||||||
if (!isList())
|
|
||||||
return llvm::formatv("${0}", idx);
|
|
||||||
llvm::SmallVector<std::string> oprs;
|
|
||||||
for (auto *opr : list)
|
|
||||||
oprs.push_back(opr->dump());
|
|
||||||
return "{ " + strJoin(oprs, ", ") + " }";
|
|
||||||
}
|
|
||||||
|
|
||||||
PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
|
|
||||||
StringRef constraint, int off) {
|
|
||||||
auto *opr = newOperand(addr, constraint);
|
|
||||||
opr->repr = [off](int idx) -> std::string {
|
|
||||||
return llvm::formatv("[ ${0} + {1} ]", idx, off);
|
|
||||||
};
|
|
||||||
|
|
||||||
return opr;
|
|
||||||
}
|
}
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -708,18 +708,20 @@ struct StoreOpConversion
|
|||||||
// TODO(Superjomn) deal with cache policy.
|
// TODO(Superjomn) deal with cache policy.
|
||||||
const bool hasL2EvictPolicy = false;
|
const bool hasL2EvictPolicy = false;
|
||||||
|
|
||||||
PtxIOInstr asmStoreInstr("st");
|
PTXBuilder ptxBuilder;
|
||||||
asmStoreInstr.predicate(maskElems[vecIdx], "b");
|
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
|
||||||
asmStoreInstr.global().b(width).v(nWords);
|
ptxStoreInstr.predicate(maskElems[vecIdx], "b")
|
||||||
|
.global()
|
||||||
|
.b(width)
|
||||||
|
.v(nWords);
|
||||||
|
|
||||||
llvm::SmallVector<std::string> asmArgs;
|
llvm::SmallVector<std::string> asmArgs;
|
||||||
|
|
||||||
Type valArgTy = IntegerType::get(ctx, width);
|
Type valArgTy = IntegerType::get(ctx, width);
|
||||||
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
auto wordTy = VectorType::get(wordNElems, valueElemTy);
|
||||||
|
|
||||||
auto *asmAddr =
|
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
||||||
asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off);
|
auto *asmArgList = ptxBuilder.newListOperand();
|
||||||
auto *asmArgList = asmStoreInstr.newList();
|
|
||||||
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
|
||||||
// llWord is a width-len composition
|
// llWord is a width-len composition
|
||||||
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
|
||||||
@@ -740,11 +742,10 @@ struct StoreOpConversion
|
|||||||
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
|
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
|
||||||
std::string constraint =
|
std::string constraint =
|
||||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||||
asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
|
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
|
||||||
}
|
}
|
||||||
|
|
||||||
asmStoreInstr.addOperand(asmAddr);
|
ptxStoreInstr(asmAddr, asmArgList);
|
||||||
asmStoreInstr.addOperand(asmArgList);
|
|
||||||
|
|
||||||
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
|
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
|
||||||
for (int i = 0; i < nWords; i++)
|
for (int i = 0; i < nWords; i++)
|
||||||
@@ -753,9 +754,9 @@ struct StoreOpConversion
|
|||||||
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
|
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
|
||||||
|
|
||||||
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
||||||
loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
|
loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands
|
||||||
asmStoreInstr.dump(), // asm_string
|
ptxBuilder.dump(), // asm_string
|
||||||
asmStoreInstr.getConstrains(), // constraints
|
ptxBuilder.getConstrains(), // constraints
|
||||||
// TODO(Superjomn) determine the side effect.
|
// TODO(Superjomn) determine the side effect.
|
||||||
true, // has_side_effects
|
true, // has_side_effects
|
||||||
false, // is_align_stack
|
false, // is_align_stack
|
||||||
@@ -1008,43 +1009,54 @@ struct LoadOpConversion
|
|||||||
// ---
|
// ---
|
||||||
// create inline asm string
|
// create inline asm string
|
||||||
// ---
|
// ---
|
||||||
// TODO: (Superjomn) refactor with AsmInstr abstraction
|
|
||||||
std::ostringstream asmOss;
|
const std::string writeConstrait =
|
||||||
asmOss << "@$" << n_words; // predicate
|
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||||
asmOss << " ld";
|
const std::string readConstrait =
|
||||||
if (op.isVolatile()) {
|
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||||
asmOss << ".volatile";
|
|
||||||
|
PTXBuilder ptxBuilder;
|
||||||
|
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
|
||||||
|
|
||||||
|
// Define the instruction opcode
|
||||||
|
ld.predicate(pred, "b")
|
||||||
|
.o("violatile", op.isVolatile())
|
||||||
|
.global()
|
||||||
|
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||||
|
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||||
|
.o("L1::evict_first",
|
||||||
|
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||||
|
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||||
|
.o("L1::cache_hint", has_l2_evict_policy)
|
||||||
|
.v(n_words)
|
||||||
|
.b(width);
|
||||||
|
|
||||||
|
// prepare asm operands
|
||||||
|
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||||
|
for (int i = 0; i < n_words; i++) {
|
||||||
|
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
|
||||||
|
dstsOpr->listAppend(opr);
|
||||||
}
|
}
|
||||||
asmOss << ".global";
|
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
|
||||||
if (op.cache() == triton::CacheModifier::CA)
|
|
||||||
asmOss << ".ca";
|
PTXBuilder::Operand *evictOpr{};
|
||||||
if (op.cache() == triton::CacheModifier::CG)
|
// Here lack a mlir::Value to bind to this operation, so disabled.
|
||||||
asmOss << ".cg";
|
// if (has_l2_evict_policy)
|
||||||
if (op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
|
||||||
asmOss << ".L1::evict_first";
|
|
||||||
if (op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
if (!evictOpr)
|
||||||
asmOss << ".L1::evict_last";
|
ld(dstsOpr, addrOpr);
|
||||||
if (has_l2_evict_policy)
|
else
|
||||||
asmOss << ".L2::cache_hint";
|
ld(dstsOpr, addrOpr, evictOpr);
|
||||||
if (n_words > 1)
|
|
||||||
asmOss << ".v" << n_words; // vector width
|
|
||||||
asmOss << ".b" << width; // word size
|
|
||||||
asmOss << " {";
|
|
||||||
for (int i = 0; i < n_words; i++) { // return values
|
|
||||||
if (i > 0)
|
|
||||||
asmOss << ",";
|
|
||||||
asmOss << "$" << i;
|
|
||||||
}
|
|
||||||
asmOss << "}";
|
|
||||||
asmOss << ", [ $" << n_words + 1; // load
|
|
||||||
asmOss << " + " << in_off << "]"; // constant offset
|
|
||||||
if (has_l2_evict_policy)
|
|
||||||
asmOss << ", $" << n_words + 2;
|
|
||||||
asmOss << ";";
|
|
||||||
SmallVector<Value> others;
|
SmallVector<Value> others;
|
||||||
if (other != nullptr) {
|
if (other) {
|
||||||
for (size_t ii = 0; ii < n_words; ii++) {
|
for (size_t ii = 0; ii < n_words; ii++) {
|
||||||
|
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||||
|
mov.predicateNot(pred, "b").o("u", width);
|
||||||
|
|
||||||
size_t size = width / nbits;
|
size_t size = width / nbits;
|
||||||
|
|
||||||
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
||||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||||
for (size_t s = 0; s < size; s++) {
|
for (size_t s = 0; s < size; s++) {
|
||||||
@@ -1056,20 +1068,19 @@ struct LoadOpConversion
|
|||||||
}
|
}
|
||||||
v = rewriter.create<LLVM::BitcastOp>(
|
v = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, IntegerType::get(getContext(), width), v);
|
loc, IntegerType::get(getContext(), width), v);
|
||||||
asmOss << "\n ";
|
|
||||||
asmOss << "@!$" << n_words << " mov.u" << width;
|
PTXInstr::Operand *opr{};
|
||||||
asmOss << " $" << ii << ", ";
|
if (otherIsSplatConstInt) {
|
||||||
std::ios_base::fmtflags flags(asmOss.flags());
|
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||||
if (otherIsSplatConstInt)
|
} else {
|
||||||
asmOss << "0x" << std::hex << splatVal;
|
opr = ptxBuilder.newOperand(v, readConstrait);
|
||||||
else {
|
|
||||||
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
|
||||||
others.push_back(v);
|
others.push_back(v);
|
||||||
}
|
}
|
||||||
asmOss.flags(flags);
|
|
||||||
asmOss << ";";
|
mov(dstsOpr->listGet(ii), opr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---
|
// ---
|
||||||
// create inline ASM signature
|
// create inline ASM signature
|
||||||
// ---
|
// ---
|
||||||
@@ -1077,39 +1088,18 @@ struct LoadOpConversion
|
|||||||
Type retTy = retTys.size() > 1
|
Type retTy = retTys.size() > 1
|
||||||
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
|
||||||
: retTys[0];
|
: retTys[0];
|
||||||
// ---
|
|
||||||
// create inline ASM constraints
|
|
||||||
// ---
|
|
||||||
std::string asmCstrt;
|
|
||||||
for (int ii = 0; ii < n_words; ii++) {
|
|
||||||
if (ii > 0)
|
|
||||||
asmCstrt += ",";
|
|
||||||
asmCstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
|
||||||
}
|
|
||||||
asmCstrt += ",b,l";
|
|
||||||
for (size_t ii = 0; ii < others.size(); ii++) {
|
|
||||||
asmCstrt += ",";
|
|
||||||
asmCstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
|
||||||
}
|
|
||||||
if (has_l2_evict_policy) {
|
|
||||||
asmCstrt += ",l";
|
|
||||||
}
|
|
||||||
// ---
|
|
||||||
// finally call inline ASM
|
|
||||||
// ---
|
|
||||||
SmallVector<Value> args = {pred, ptr};
|
|
||||||
for (Value v : others) {
|
|
||||||
args.push_back(v);
|
|
||||||
}
|
|
||||||
// TODO: if (has_l2_evict_policy)
|
// TODO: if (has_l2_evict_policy)
|
||||||
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
|
||||||
LLVM::AsmDialect::AD_ATT);
|
LLVM::AsmDialect::AD_ATT);
|
||||||
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
|
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
|
||||||
loc, retTy, /*operands=*/args, /*asm_string=*/asmOss.str(),
|
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
|
||||||
/*constraints=*/asmCstrt, /*has_side_effects=*/true,
|
/*asm_string=*/ptxBuilder.dump(),
|
||||||
|
/*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true,
|
||||||
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
|
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
|
||||||
/*operand_attrs=*/ArrayAttr());
|
/*operand_attrs=*/ArrayAttr());
|
||||||
Value ret = inlineAsmOp.getResult(0);
|
Value ret = inlineAsmOp.getResult(0);
|
||||||
|
|
||||||
// ---
|
// ---
|
||||||
// extract and store return values
|
// extract and store return values
|
||||||
// ---
|
// ---
|
||||||
@@ -1135,7 +1125,7 @@ struct LoadOpConversion
|
|||||||
loc, elemTy, rets[ii / tmp], vecIdx);
|
loc, elemTy, rets[ii / tmp], vecIdx);
|
||||||
loadedVals.push_back(loaded);
|
loadedVals.push_back(loaded);
|
||||||
}
|
}
|
||||||
}
|
} // end vec
|
||||||
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
|
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
|
||||||
Value resultStruct =
|
Value resultStruct =
|
||||||
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
|
||||||
@@ -1314,7 +1304,8 @@ public:
|
|||||||
auto axisAnalysis = runAxisAnalysis(mod);
|
auto axisAnalysis = runAxisAnalysis(mod);
|
||||||
|
|
||||||
// We set a higher benefit here to ensure triton's patterns runs before
|
// We set a higher benefit here to ensure triton's patterns runs before
|
||||||
// arith patterns for some encoding not supported by the community patterns.
|
// arith patterns for some encoding not supported by the community
|
||||||
|
// patterns.
|
||||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
*axisAnalysis, 10 /*benefit*/);
|
*axisAnalysis, 10 /*benefit*/);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user