[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)

This commit is contained in:
Keren Zhou
2022-10-04 09:37:00 -07:00
committed by GitHub
parent f9d7f2f126
commit 289ff293cc
9 changed files with 412 additions and 57 deletions

View File

@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
@@ -99,8 +100,9 @@ struct PTXBuilder {
std::string dump() const;
};
template <typename INSTR = PTXInstr> INSTR *create(const std::string &name) {
instrs.emplace_back(std::make_unique<INSTR>(this, name));
template <typename INSTR = PTXInstr, typename... Args>
INSTR *create(Args &&...args) {
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
return static_cast<INSTR *>(instrs.back().get());
}
@@ -188,6 +190,7 @@ struct PTXInstrCommon {
using Operand = PTXBuilder::Operand;
// clang-format off
PTXInstrExecution& operator()() { return call({}); }
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
@@ -238,17 +241,17 @@ struct PTXInstr : public PTXInstrBase<PTXInstr> {
// PtxIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
using PTXInstrBase<PtxIOInstr>::PTXInstrBase;
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
// Add ".global" suffix to instruction
PtxIOInstr &global(bool predicate = true) {
PTXIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Add ".v" suffix to instruction
PtxIOInstr &v(int vecWidth, bool predicate = true) {
PTXIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
@@ -256,12 +259,43 @@ struct PtxIOInstr : public PTXInstrBase<PtxIOInstr> {
}
// Add ".b" suffix to instruction
PtxIOInstr &b(int width) {
PTXIOInstr &b(int width) {
o("b" + std::to_string(width));
return *this;
}
};
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
: PTXInstrBase(builder, "cp.async") {}
};
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("commit_group");
}
};
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("wait_group");
}
};
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier,
triton::EvictionPolicy policy)
: PTXCpAsyncInstrBase(builder) {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
o("L2::" + triton::stringifyEvictionPolicy(policy).str());
}
};
// Record the operands and context for "launching" a PtxInstr.
struct PTXInstrExecution {
using Operand = PTXBuilder::Operand;

View File

@@ -16,7 +16,7 @@ def TT_CacheModifierAttr : I32EnumAttr<
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "normal">,
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {

View File

@@ -24,6 +24,8 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);

View File

@@ -54,7 +54,7 @@ in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
A_{1, 0} A_{0, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _