[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)
This PR 1. adds an `onlyAttachMLIRArgs` argument to the `PTXInstrCommon::call` method to support passing in a whole PTX code snippet, and 2. refines the APIs and simplifies their usage.
This commit is contained in:
@@ -201,10 +201,17 @@ struct PTXInstrCommon {
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
// Set operands of this instruction.
|
// Set operands of this instruction.
|
||||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
|
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||||
|
bool onlyAttachMLIRArgs = false);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
|
// "Call" the instruction with operands.
|
||||||
|
// \param oprs The operands of this instruction.
|
||||||
|
// \param onlyAttachMLIRArgs Indicates that it simply attaches the MLIR arguments
|
||||||
|
// to the inline Asm without generating the operand ids (such as $0, $1) in PTX
|
||||||
|
// code.
|
||||||
|
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||||
|
bool onlyAttachMLIRArgs = false);
|
||||||
|
|
||||||
PTXBuilder *builder{};
|
PTXBuilder *builder{};
|
||||||
llvm::SmallVector<std::string, 4> instrParts;
|
llvm::SmallVector<std::string, 4> instrParts;
|
||||||
@@ -234,70 +241,18 @@ template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
|||||||
|
|
||||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||||
};
|
|
||||||
|
|
||||||
// A helper for PTX ld/st instruction.
|
// Append a ".global" to the instruction.
|
||||||
// Usage:
|
PTXInstr &global();
|
||||||
// PtxIOInstr store("st");
|
|
||||||
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
|
|
||||||
// store.addAddr(addrValue, "l", off);
|
|
||||||
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
|
|
||||||
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
|
|
||||||
|
|
||||||
// Add ".global" suffix to instruction
|
// Append a ".shared" to the instruction.
|
||||||
PTXIOInstr &global(bool predicate = true) {
|
PTXInstr &shared();
|
||||||
o("global", predicate);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add ".shared" suffix to instruction
|
// Append a ".v[0-9]+" to the instruction
|
||||||
PTXIOInstr &shared(bool predicate = true) {
|
PTXInstr &v(int vecWidth, bool predicate = true);
|
||||||
o("shared", predicate);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add ".v" suffix to instruction
|
// Append a ".b[0-9]+" to the instruction
|
||||||
PTXIOInstr &v(int vecWidth, bool predicate = true) {
|
PTXInstr &b(int width);
|
||||||
if (vecWidth > 1) {
|
|
||||||
o("v" + std::to_string(vecWidth), predicate);
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add ".b" suffix to instruction
|
|
||||||
PTXIOInstr &b(int width) {
|
|
||||||
o("b" + std::to_string(width));
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
|
|
||||||
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
|
|
||||||
: PTXInstrBase(builder, "cp.async") {}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
|
|
||||||
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
|
|
||||||
: PTXCpAsyncInstrBase(builder) {
|
|
||||||
o("commit_group");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
|
|
||||||
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
|
|
||||||
: PTXCpAsyncInstrBase(builder) {
|
|
||||||
o("wait_group");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
|
|
||||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
|
||||||
triton::CacheModifier modifier)
|
|
||||||
: PTXCpAsyncInstrBase(builder) {
|
|
||||||
o(triton::stringifyCacheModifier(modifier).str());
|
|
||||||
o("shared");
|
|
||||||
o("global");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Record the operands and context for "launching" a PtxInstr.
|
// Record the operands and context for "launching" a PtxInstr.
|
||||||
@@ -308,8 +263,10 @@ struct PTXInstrExecution {
|
|||||||
|
|
||||||
PTXInstrExecution() = default;
|
PTXInstrExecution() = default;
|
||||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||||
llvm::ArrayRef<Operand *> oprs)
|
llvm::ArrayRef<Operand *> oprs,
|
||||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
|
bool onlyAttachMLIRArgs)
|
||||||
|
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
|
||||||
|
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
|
||||||
|
|
||||||
// Prefix a predicate to the instruction.
|
// Prefix a predicate to the instruction.
|
||||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||||
@@ -330,6 +287,22 @@ struct PTXInstrExecution {
|
|||||||
|
|
||||||
PTXInstrCommon *instr{};
|
PTXInstrCommon *instr{};
|
||||||
Operand *pred{};
|
Operand *pred{};
|
||||||
|
bool onlyAttachMLIRArgs{};
|
||||||
|
};
|
||||||
|
|
||||||
|
//// =============================== Some instruction wrappers
|
||||||
|
///===============================
|
||||||
|
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||||
|
// PTX code with some trivial C++ code.
|
||||||
|
|
||||||
|
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||||
|
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||||
|
triton::CacheModifier modifier)
|
||||||
|
: PTXInstrBase(builder, "cp.async") {
|
||||||
|
o(triton::stringifyCacheModifier(modifier).str());
|
||||||
|
o("shared");
|
||||||
|
o("global");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
@@ -128,19 +128,26 @@ std::string PTXBuilder::dump() const {
|
|||||||
return strJoin(lines, "\n\t");
|
return strJoin(lines, "\n\t");
|
||||||
}
|
}
|
||||||
|
|
||||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
|
||||||
|
bool onlyAttachMLIRArgs) {
|
||||||
builder->executions.emplace_back(
|
builder->executions.emplace_back(
|
||||||
std::make_unique<PTXInstrExecution>(this, oprs));
|
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
|
||||||
return *builder->executions.back();
|
return *builder->executions.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
||||||
return call(oprs);
|
bool onlyAttachMLIRArgs) {
|
||||||
|
return call(oprs, onlyAttachMLIRArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PTXInstrExecution::dump() const {
|
std::string PTXInstrExecution::dump() const {
|
||||||
std::string osStr;
|
std::string osStr;
|
||||||
llvm::raw_string_ostream os(osStr);
|
llvm::raw_string_ostream os(osStr);
|
||||||
|
|
||||||
|
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||||
|
if (onlyAttachMLIRArgs)
|
||||||
|
return instrRepr;
|
||||||
|
|
||||||
if (pred) {
|
if (pred) {
|
||||||
if (!pred->repr)
|
if (!pred->repr)
|
||||||
os << "@" << pred->dump() << " ";
|
os << "@" << pred->dump() << " ";
|
||||||
@@ -148,8 +155,6 @@ std::string PTXInstrExecution::dump() const {
|
|||||||
os << pred->repr(pred->idx) << " ";
|
os << pred->repr(pred->idx) << " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
|
||||||
|
|
||||||
llvm::SmallVector<std::string, 4> argReprs;
|
llvm::SmallVector<std::string, 4> argReprs;
|
||||||
for (auto *arg : argsInOrder) {
|
for (auto *arg : argsInOrder) {
|
||||||
argReprs.push_back(arg->dump());
|
argReprs.push_back(arg->dump());
|
||||||
@@ -174,5 +179,27 @@ PTXInstrExecution::getArgList() const {
|
|||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PTXInstr &PTXInstr::global() {
|
||||||
|
o("global");
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXInstr &PTXInstr::shared() {
|
||||||
|
o("shared");
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
|
||||||
|
if (vecWidth > 1) {
|
||||||
|
o("v" + std::to_string(vecWidth), predicate);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXInstr &PTXInstr::b(int width) {
|
||||||
|
o("b" + std::to_string(width));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -388,9 +388,9 @@ static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
|||||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||||
|
|
||||||
PTXBuilder builder;
|
PTXBuilder builder;
|
||||||
auto &st = builder.create<PTXIOInstr>("st")->shared().b(bits);
|
|
||||||
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
||||||
auto *valOpr = builder.newOperand(val, c);
|
auto *valOpr = builder.newOperand(val, c);
|
||||||
|
auto &st = builder.create<>("st")->shared().b(bits);
|
||||||
st(ptrOpr, valOpr).predicate(pred, "b");
|
st(ptrOpr, valOpr).predicate(pred, "b");
|
||||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||||
}
|
}
|
||||||
@@ -1005,7 +1005,6 @@ struct LoadOpConversion
|
|||||||
const bool hasL2EvictPolicy = false;
|
const bool hasL2EvictPolicy = false;
|
||||||
|
|
||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
|
||||||
|
|
||||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||||
|
|
||||||
@@ -1025,16 +1024,18 @@ struct LoadOpConversion
|
|||||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||||
|
|
||||||
// Define the instruction opcode
|
// Define the instruction opcode
|
||||||
ld.o("volatile", op.isVolatile())
|
auto &ld = ptxBuilder.create<>("ld")
|
||||||
.global()
|
->o("volatile", op.isVolatile())
|
||||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
.global()
|
||||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||||
.o("L1::evict_first",
|
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
.o("L1::evict_first",
|
||||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
.o("L1::evict_last",
|
||||||
.v(nWords)
|
op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||||
.b(width);
|
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||||
|
.v(nWords)
|
||||||
|
.b(width);
|
||||||
|
|
||||||
PTXBuilder::Operand *evictOpr{};
|
PTXBuilder::Operand *evictOpr{};
|
||||||
|
|
||||||
@@ -1049,8 +1050,8 @@ struct LoadOpConversion
|
|||||||
|
|
||||||
if (other) {
|
if (other) {
|
||||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
PTXInstr &mov =
|
||||||
mov.o("u" + std::to_string(width));
|
ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
|
||||||
|
|
||||||
size_t size = width / valueElemNbits;
|
size_t size = width / valueElemNbits;
|
||||||
|
|
||||||
@@ -1222,7 +1223,7 @@ struct StoreOpConversion
|
|||||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||||
|
|
||||||
auto &ptxStoreInstr =
|
auto &ptxStoreInstr =
|
||||||
ptxBuilder.create<PTXIOInstr>("st")->global().v(nWords).b(width);
|
ptxBuilder.create<>("st")->global().v(nWords).b(width);
|
||||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||||
|
|
||||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||||
@@ -4802,7 +4803,7 @@ struct AsyncWaitOpConversion
|
|||||||
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
|
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
|
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
|
||||||
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
||||||
asyncWaitOp(ptxBuilder.newConstantOperand(num));
|
asyncWaitOp(ptxBuilder.newConstantOperand(num));
|
||||||
|
|
||||||
@@ -5025,7 +5026,7 @@ struct InsertSliceAsyncOpConversion
|
|||||||
}
|
}
|
||||||
|
|
||||||
PTXBuilder ptxBuilder;
|
PTXBuilder ptxBuilder;
|
||||||
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
|
ptxBuilder.create<>("cp.async.commit_group")->operator()();
|
||||||
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||||
rewriter.replaceOp(op, llDst);
|
rewriter.replaceOp(op, llDst);
|
||||||
return success();
|
return success();
|
||||||
@@ -5178,9 +5179,7 @@ struct AtomicRMWOpConversion
|
|||||||
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
|
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
|
||||||
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
|
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
|
||||||
|
|
||||||
auto &atom = *ptxBuilder.create<>("atom");
|
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
|
||||||
|
|
||||||
atom.o("global").o("gpu");
|
|
||||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||||
auto sBits = std::to_string(valueElemNbits);
|
auto sBits = std::to_string(valueElemNbits);
|
||||||
switch (atomicRmwAttr) {
|
switch (atomicRmwAttr) {
|
||||||
|
@@ -76,7 +76,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
|
|||||||
|
|
||||||
auto &ld =
|
auto &ld =
|
||||||
builder
|
builder
|
||||||
.create<PTXIOInstr>("ld") //
|
.create<>("ld") //
|
||||||
->o("volatile", isVolatile)
|
->o("volatile", isVolatile)
|
||||||
.global()
|
.global()
|
||||||
.o("ca", cache == CacheModifier::CA)
|
.o("ca", cache == CacheModifier::CA)
|
||||||
@@ -121,5 +121,20 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) {
|
|||||||
EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
|
EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
|
||||||
|
PTXBuilder builder;
|
||||||
|
const char *ptxCode =
|
||||||
|
".param .b64 param0;\n" // prepare param0 (format string)
|
||||||
|
"st.param.b64 [param0], %0;\n";
|
||||||
|
|
||||||
|
auto &ptxSnippet = *builder.create(ptxCode);
|
||||||
|
auto *opr = builder.newOperand(v[0], "r");
|
||||||
|
ptxSnippet({opr}, true);
|
||||||
|
|
||||||
|
EXPECT_EQ(builder.dump(), ptxCode);
|
||||||
|
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]);
|
||||||
|
ASSERT_EQ(builder.getAllMLIRArgs().size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
Reference in New Issue
Block a user