[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)
This PR does 1. Add `onlyBindMLIRArgs` argument to `PTXInstrCommon::call` method to support passing in a whole PTX code snippet 2. Refine the APIs and simplify the code usage.
This commit is contained in:
@@ -201,10 +201,17 @@ struct PTXInstrCommon {
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
protected:
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
|
||||
// "Call" the instruction with operands.
|
||||
// \param oprs The operands of this instruction.
|
||||
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
|
||||
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
|
||||
// code.
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
PTXBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
@@ -234,70 +241,18 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
|
||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||
};
|
||||
|
||||
// A helper for PTX ld/st instruction.
|
||||
// Usage:
|
||||
// PtxIOInstr store("st");
|
||||
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
||||
// store.addAddr(addrValue, "l", off);
|
||||
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
|
||||
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
|
||||
// Append a ".global" to the instruction.
|
||||
PTXInstr &global();
|
||||
|
||||
// Add ".global" suffix to instruction
|
||||
PTXIOInstr &global(bool predicate = true) {
|
||||
o("global", predicate);
|
||||
return *this;
|
||||
}
|
||||
// Append a ".shared" to the instruction.
|
||||
PTXInstr &shared();
|
||||
|
||||
// Add ".shared" suffix to instruction
|
||||
PTXIOInstr &shared(bool predicate = true) {
|
||||
o("shared", predicate);
|
||||
return *this;
|
||||
}
|
||||
// Append a ".v[0-9]+" to the instruction
|
||||
PTXInstr &v(int vecWidth, bool predicate = true);
|
||||
|
||||
// Add ".v" suffix to instruction
|
||||
PTXIOInstr &v(int vecWidth, bool predicate = true) {
|
||||
if (vecWidth > 1) {
|
||||
o("v" + std::to_string(vecWidth), predicate);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add ".b" suffix to instruction
|
||||
PTXIOInstr &b(int width) {
|
||||
o("b" + std::to_string(width));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
|
||||
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
|
||||
: PTXInstrBase(builder, "cp.async") {}
|
||||
};
|
||||
|
||||
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
|
||||
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
|
||||
: PTXCpAsyncInstrBase(builder) {
|
||||
o("commit_group");
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
|
||||
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
|
||||
: PTXCpAsyncInstrBase(builder) {
|
||||
o("wait_group");
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
|
||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||
triton::CacheModifier modifier)
|
||||
: PTXCpAsyncInstrBase(builder) {
|
||||
o(triton::stringifyCacheModifier(modifier).str());
|
||||
o("shared");
|
||||
o("global");
|
||||
}
|
||||
// Append a".b[0-9]+" to the instruction
|
||||
PTXInstr &b(int width);
|
||||
};
|
||||
|
||||
// Record the operands and context for "launching" a PtxInstr.
|
||||
@@ -308,8 +263,10 @@ struct PTXInstrExecution {
|
||||
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs)
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
|
||||
llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs)
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
|
||||
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
@@ -330,6 +287,22 @@ struct PTXInstrExecution {
|
||||
|
||||
PTXInstrCommon *instr{};
|
||||
Operand *pred{};
|
||||
bool onlyAttachMLIRArgs{};
|
||||
};
|
||||
|
||||
//// =============================== Some instruction wrappers
|
||||
///===============================
|
||||
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||
// PTX code with some trivial C++ code.
|
||||
|
||||
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||
triton::CacheModifier modifier)
|
||||
: PTXInstrBase(builder, "cp.async") {
|
||||
o(triton::stringifyCacheModifier(modifier).str());
|
||||
o("shared");
|
||||
o("global");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
|
Reference in New Issue
Block a user