[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)

This PR does

1. Add an `onlyAttachMLIRArgs` argument to the `PTXInstrCommon::call` method to
support passing in a whole PTX code snippet
2. Refine the APIs and simplify the code usage.
This commit is contained in:
Yan Chunwei
2022-11-10 13:41:52 +08:00
committed by GitHub
parent 4640023d9b
commit 8832e32683
4 changed files with 105 additions and 91 deletions

View File

@@ -201,10 +201,17 @@ struct PTXInstrCommon {
// clang-format on
// Set operands of this instruction.
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
protected:
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
// "Call" the instruction with operands.
// \param oprs The operands of this instruction.
// \param onlyAttachMLIRArgs Indicates that the operands are simply attached as
// MLIR arguments to the inline asm, without generating operand ids (such as
// $0, $1) in the PTX code.
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
@@ -234,70 +241,18 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
// Concrete instruction type that adds no members of its own; it only inherits
// the constructors of PTXInstrBase<PTXInstr>.
struct PTXInstr : public PTXInstrBase<PTXInstr> {
using PTXInstrBase<PTXInstr>::PTXInstrBase;
};
// A helper for PTX ld/st instructions.
// Usage:
// PTXIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
// Append a ".global" to the instruction.
PTXInstr &global();
// Add ".global" suffix to instruction.
// `predicate` is forwarded unchanged as the second argument of o().
PTXIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Append a ".shared" to the instruction.
PTXInstr &shared();
// Add ".shared" suffix to instruction.
// `predicate` is forwarded unchanged as the second argument of o().
PTXIOInstr &shared(bool predicate = true) {
o("shared", predicate);
return *this;
}
// Append a ".v[0-9]+" to the instruction
PTXInstr &v(int vecWidth, bool predicate = true);
// Add ".v<vecWidth>" suffix to instruction.
// Nothing is appended when vecWidth <= 1; the instruction is returned as is.
PTXIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
// Add ".b<width>" suffix to instruction; the suffix text is "b" followed by
// the decimal value of `width`.
PTXIOInstr &b(int width) {
o("b" + std::to_string(width));
return *this;
}
};
// Base type for the cp.async wrappers: forwards `builder` to PTXInstrBase and
// fixes the instruction name to "cp.async".
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
: PTXInstrBase(builder, "cp.async") {}
};
// cp.async wrapper whose constructor passes "commit_group" to o() on top of
// the base instruction.
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("commit_group");
}
};
// cp.async wrapper whose constructor passes "wait_group" to o() on top of the
// base instruction.
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("wait_group");
}
};
// cp.async wrapper whose constructor passes, in order, the stringified cache
// modifier, "shared" and "global" to o() on top of the base instruction.
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXCpAsyncInstrBase(builder) {
// The order of the o() calls determines the order of the appended parts.
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
// Append a ".b[0-9]+" to the instruction
PTXInstr &b(int width);
};
// Record the operands and context for "launching" a PtxInstr.
@@ -308,8 +263,10 @@ struct PTXInstrExecution {
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
@@ -330,6 +287,22 @@ struct PTXInstrExecution {
PTXInstrCommon *instr{};
Operand *pred{};
bool onlyAttachMLIRArgs{};
};
//// =============================== Some instruction wrappers
///===============================
// We add the wrappers to make the usage more intuitive by avoiding mixing the
// PTX code with some trivial C++ code.
// Wrapper for a "cp.async" instruction: the constructor names the instruction
// "cp.async" and then passes, in order, the stringified cache modifier,
// "shared" and "global" to o().
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXInstrBase(builder, "cp.async") {
// The order of the o() calls determines the order of the appended parts.
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
};
} // namespace triton