[Triton-MLIR][BACKEND] Refine ptxbuilder (#867)

This PR does

1. Add `onlyBindMLIRArgs` argument to `PTXInstrCommon::call` method to
support passing in a whole PTX code snippet
2. Refine the APIs and simplify the code usage.
This commit is contained in:
Yan Chunwei
2022-11-10 13:41:52 +08:00
committed by GitHub
parent 4640023d9b
commit 8832e32683
4 changed files with 105 additions and 91 deletions

View File

@@ -388,9 +388,9 @@ static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
PTXBuilder builder;
auto &st = builder.create<PTXIOInstr>("st")->shared().b(bits);
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
}
@@ -1005,7 +1005,6 @@ struct LoadOpConversion
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
@@ -1025,16 +1024,18 @@ struct LoadOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode
ld.o("volatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);
auto &ld = ptxBuilder.create<>("ld")
->o("volatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);
PTXBuilder::Operand *evictOpr{};
@@ -1049,8 +1050,8 @@ struct LoadOpConversion
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
mov.o("u" + std::to_string(width));
PTXInstr &mov =
ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
size_t size = width / valueElemNbits;
@@ -1222,7 +1223,7 @@ struct StoreOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<PTXIOInstr>("st")->global().v(nWords).b(width);
ptxBuilder.create<>("st")->global().v(nWords).b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -4802,7 +4803,7 @@ struct AsyncWaitOpConversion
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));
@@ -5025,7 +5026,7 @@ struct InsertSliceAsyncOpConversion
}
PTXBuilder ptxBuilder;
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
ptxBuilder.create<>("cp.async.commit_group")->operator()();
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
rewriter.replaceOp(op, llDst);
return success();
@@ -5178,9 +5179,7 @@ struct AtomicRMWOpConversion
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
auto &atom = *ptxBuilder.create<>("atom");
atom.o("global").o("gpu");
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNbits);
switch (atomicRmwAttr) {