|
|
|
@@ -2,6 +2,7 @@
|
|
|
|
|
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
|
|
|
|
|
|
|
|
|
#include "mlir/IR/Value.h"
|
|
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
|
|
|
#include <memory>
#include <utility>
|
|
|
|
@@ -99,8 +100,9 @@ struct PTXBuilder {
|
|
|
|
|
std::string dump() const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
|
|
|
|
|
instrs.emplace_back(std::make_unique<INSTR>(this, name));
|
|
|
|
|
template <typename INSTR = PTXInstr, typename... Args>
|
|
|
|
|
INSTR *create(Args &&...args) {
|
|
|
|
|
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
|
|
|
|
|
return static_cast<INSTR *>(instrs.back().get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -188,6 +190,7 @@ struct PTXInstrCommon {
|
|
|
|
|
using Operand = PTXBuilder::Operand;
|
|
|
|
|
|
|
|
|
|
// clang-format off
|
|
|
|
|
// Call operators for zero to three operands. Each overload packs its operand
// pointers into a braced list, in argument order, and forwards it to call().
PTXInstrExecution& operator()() {
  return call({});
}
PTXInstrExecution& operator()(Operand* a) {
  return call({a});
}
PTXInstrExecution& operator()(Operand* a, Operand* b) {
  return call({a, b});
}
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) {
  return call({a, b, c});
}
|
|
|
|
@@ -238,17 +241,17 @@ struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
|
|
|
|
// PTXIOInstr store("st");
|
|
|
|
|
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
|
|
|
|
// store.addAddr(addrValue, "l", off);
|
|
|
|
|
struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
|
|
|
|
using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
|
|
|
|
|
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
|
|
|
|
|
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
|
|
|
|
|
|
|
|
|
|
// Add ".global" suffix to instruction
|
|
|
|
|
PtxIOInstr &global(bool predicate = true) {
|
|
|
|
|
// Passes the "global" suffix and `predicate` on to o(), then returns this
// instruction so further modifiers can be chained.
PTXIOInstr &global(bool predicate = true) {
  this->o("global", predicate);
  return *this;
}
|
|
|
|
|
|
|
|
|
|
// Add ".v" suffix to instruction
|
|
|
|
|
PtxIOInstr &v(int vecWidth, bool predicate = true) {
|
|
|
|
|
PTXIOInstr &v(int vecWidth, bool predicate = true) {
|
|
|
|
|
if (vecWidth > 1) {
|
|
|
|
|
o("v" + std::to_string(vecWidth), predicate);
|
|
|
|
|
}
|
|
|
|
@@ -256,12 +259,43 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Add ".b" suffix to instruction
|
|
|
|
|
PtxIOInstr &b(int width) {
|
|
|
|
|
// Passes the suffix "b<width>" (e.g. "b32" for width == 32) to o(), then
// returns this instruction so further modifiers can be chained.
PTXIOInstr &b(int width) {
  this->o("b" + std::to_string(width));
  return *this;
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Shared base for the cp.async instructions declared below it: it only
// constructs the underlying PTXInstrBase with the instruction name
// "cp.async". Derived types add their own suffixes in their constructors.
struct PTXCpAsyncInstrBase : PTXInstrBase<PTXCpAsyncInstrBase> {
  explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
      : PTXInstrBase<PTXCpAsyncInstrBase>(builder, "cp.async") {}
};
|
|
|
|
|
|
|
|
|
|
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
|
|
|
|
|
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
|
|
|
|
|
: PTXCpAsyncInstrBase(builder) {
|
|
|
|
|
o("commit_group");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
|
|
|
|
|
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
|
|
|
|
|
: PTXCpAsyncInstrBase(builder) {
|
|
|
|
|
o("wait_group");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
|
|
|
|
|
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
|
|
|
|
triton::CacheModifier modifier,
|
|
|
|
|
triton::EvictionPolicy policy)
|
|
|
|
|
: PTXCpAsyncInstrBase(builder) {
|
|
|
|
|
o(triton::stringifyCacheModifier(modifier).str());
|
|
|
|
|
o("shared");
|
|
|
|
|
o("global");
|
|
|
|
|
o("L2::" + triton::stringifyEvictionPolicy(policy).str());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Record the operands and context for "launching" a PTXInstr.
|
|
|
|
|
struct PTXInstrExecution {
|
|
|
|
|
using Operand = PTXBuilder::Operand;
|
|
|
|
|