[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)
This PR does two things: 1. It adds an `onlyBindMLIRArgs` argument (spelled `onlyAttachMLIRArgs` in the code below) to the `PTXInstrCommon::call` method, to support passing in a whole PTX code snippet. 2. It refines the APIs and simplifies their usage in the code.
This commit is contained in:
@@ -128,19 +128,26 @@ std::string PTXBuilder::dump() const {
|
||||
return strJoin(lines, "\n\t");
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs) {
|
||||
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
builder->executions.emplace_back(
|
||||
std::make_unique<PTXInstrExecution>(this, oprs));
|
||||
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
|
||||
return *builder->executions.back();
|
||||
}
|
||||
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
|
||||
return call(oprs);
|
||||
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs) {
|
||||
return call(oprs, onlyAttachMLIRArgs);
|
||||
}
|
||||
|
||||
std::string PTXInstrExecution::dump() const {
|
||||
std::string osStr;
|
||||
llvm::raw_string_ostream os(osStr);
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
if (onlyAttachMLIRArgs)
|
||||
return instrRepr;
|
||||
|
||||
if (pred) {
|
||||
if (!pred->repr)
|
||||
os << "@" << pred->dump() << " ";
|
||||
@@ -148,8 +155,6 @@ std::string PTXInstrExecution::dump() const {
|
||||
os << pred->repr(pred->idx) << " ";
|
||||
}
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
|
||||
llvm::SmallVector<std::string, 4> argReprs;
|
||||
for (auto *arg : argsInOrder) {
|
||||
argReprs.push_back(arg->dump());
|
||||
@@ -174,5 +179,27 @@ PTXInstrExecution::getArgList() const {
|
||||
return args;
|
||||
}
|
||||
|
||||
/// Chainable modifier: forwards the literal "global" to o() and hands back
/// this instruction so that further modifiers can be appended.
/// NOTE(review): o() is defined elsewhere; presumably it records the string
/// as one more dot-separated part of the opcode -- confirm against o().
PTXInstr &PTXInstr::global() {
  this->o("global");
  return *this;
}
|
||||
|
||||
/// Chainable modifier: forwards the literal "shared" to o() and hands back
/// this instruction so that further modifiers can be appended.
/// NOTE(review): o() is defined elsewhere; presumably it records the string
/// as one more dot-separated part of the opcode -- confirm against o().
PTXInstr &PTXInstr::shared() {
  this->o("shared");
  return *this;
}
|
||||
|
||||
/// Chainable vector-width modifier.
///
/// \param vecWidth  requested width; widths of 1 or less add nothing.
/// \param predicate forwarded unchanged as the second argument of o().
///                  NOTE(review): presumably o() skips the part when this is
///                  false -- confirm against o().
/// \returns this instruction, to allow further chaining.
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
  // Nothing to add for widths of 1 or less.
  if (vecWidth <= 1)
    return *this;

  this->o("v" + std::to_string(vecWidth), predicate);
  return *this;
}
|
||||
|
||||
/// Chainable bit-width modifier: forwards "b<width>" (for example "b32" when
/// width == 32) to o().
///
/// \param width number rendered in decimal after the leading 'b'.
/// \returns this instruction, to allow further chaining.
PTXInstr &PTXInstr::b(int width) {
  this->o(std::string("b") + std::to_string(width));
  return *this;
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
@@ -388,9 +388,9 @@ static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
|
||||
PTXBuilder builder;
|
||||
auto &st = builder.create<PTXIOInstr>("st")->shared().b(bits);
|
||||
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
||||
auto *valOpr = builder.newOperand(val, c);
|
||||
auto &st = builder.create<>("st")->shared().b(bits);
|
||||
st(ptrOpr, valOpr).predicate(pred, "b");
|
||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||
}
|
||||
@@ -1005,7 +1005,6 @@ struct LoadOpConversion
|
||||
const bool hasL2EvictPolicy = false;
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
|
||||
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
|
||||
@@ -1025,16 +1024,18 @@ struct LoadOpConversion
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
// Define the instruction opcode
|
||||
ld.o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
auto &ld = ptxBuilder.create<>("ld")
|
||||
->o("volatile", op.isVolatile())
|
||||
.global()
|
||||
.o("ca", op.cache() == triton::CacheModifier::CA)
|
||||
.o("cg", op.cache() == triton::CacheModifier::CG)
|
||||
.o("L1::evict_first",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last",
|
||||
op.evict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.o("L1::cache_hint", hasL2EvictPolicy)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
|
||||
PTXBuilder::Operand *evictOpr{};
|
||||
|
||||
@@ -1049,8 +1050,8 @@ struct LoadOpConversion
|
||||
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
PTXInstr &mov = *ptxBuilder.create<>("mov");
|
||||
mov.o("u" + std::to_string(width));
|
||||
PTXInstr &mov =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
|
||||
@@ -1222,7 +1223,7 @@ struct StoreOpConversion
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
auto &ptxStoreInstr =
|
||||
ptxBuilder.create<PTXIOInstr>("st")->global().v(nWords).b(width);
|
||||
ptxBuilder.create<>("st")->global().v(nWords).b(width);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
@@ -4802,7 +4803,7 @@ struct AsyncWaitOpConversion
|
||||
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
|
||||
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
|
||||
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
|
||||
asyncWaitOp(ptxBuilder.newConstantOperand(num));
|
||||
|
||||
@@ -5025,7 +5026,7 @@ struct InsertSliceAsyncOpConversion
|
||||
}
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
|
||||
ptxBuilder.create<>("cp.async.commit_group")->operator()();
|
||||
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||
rewriter.replaceOp(op, llDst);
|
||||
return success();
|
||||
@@ -5178,9 +5179,7 @@ struct AtomicRMWOpConversion
|
||||
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
|
||||
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
|
||||
|
||||
auto &atom = *ptxBuilder.create<>("atom");
|
||||
|
||||
atom.o("global").o("gpu");
|
||||
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
auto sBits = std::to_string(valueElemNbits);
|
||||
switch (atomicRmwAttr) {
|
||||
|
Reference in New Issue
Block a user