From 8832e326833ecf03ede2d3e96c161d4c26bf023f Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 10 Nov 2022 13:41:52 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Refine ptxbuilder (#867) This PR does 1. Add `onlyAttachMLIRArgs` argument to `PTXInstrCommon::call` method to support passing in a whole PTX code snippet 2. Refine the APIs and simplify the code usage. --- .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 101 +++++++----------- .../TritonGPUToLLVM/PtxAsmFormat.cpp | 39 +++++-- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 39 ++++--- .../TritonGPUToLLVM/PtxAsmFormatTest.cpp | 17 ++- 4 files changed, 105 insertions(+), 91 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index 966da36c2..2765f611d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -201,10 +201,17 @@ struct PTXInstrCommon { // clang-format on // Set operands of this instruction. - PTXInstrExecution &operator()(llvm::ArrayRef oprs); + PTXInstrExecution &operator()(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); protected: - PTXInstrExecution &call(llvm::ArrayRef oprs); + // "Call" the instruction with operands. + // \param oprs The operands of this instruction. + // \param onlyAttachMLIRArgs Indicates that it only attaches the MLIR Arguments + // to the inline Asm without generating the operand ids(such as $0, $1) in PTX + // code. + PTXInstrExecution &call(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); PTXBuilder *builder{}; llvm::SmallVector instrParts; @@ -234,70 +241,18 @@ template struct PTXInstrBase : public PTXInstrCommon { struct PTXInstr : public PTXInstrBase { using PTXInstrBase::PTXInstrBase; -}; -// A helper for PTX ld/st instruction.
-// Usage: -// PtxIOInstr store("st"); -// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1 -// store.addAddr(addrValue, "l", off); -struct PTXIOInstr : public PTXInstrBase { - using PTXInstrBase::PTXInstrBase; + // Append a ".global" to the instruction. + PTXInstr &global(); - // Add ".global" suffix to instruction - PTXIOInstr &global(bool predicate = true) { - o("global", predicate); - return *this; - } + // Append a ".shared" to the instruction. + PTXInstr &shared(); - // Add ".shared" suffix to instruction - PTXIOInstr &shared(bool predicate = true) { - o("shared", predicate); - return *this; - } + // Append a ".v[0-9]+" to the instruction + PTXInstr &v(int vecWidth, bool predicate = true); - // Add ".v" suffix to instruction - PTXIOInstr &v(int vecWidth, bool predicate = true) { - if (vecWidth > 1) { - o("v" + std::to_string(vecWidth), predicate); - } - return *this; - } - - // Add ".b" suffix to instruction - PTXIOInstr &b(int width) { - o("b" + std::to_string(width)); - return *this; - } -}; - -struct PTXCpAsyncInstrBase : public PTXInstrBase { - explicit PTXCpAsyncInstrBase(PTXBuilder *builder) - : PTXInstrBase(builder, "cp.async") {} -}; - -struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase { - explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder) - : PTXCpAsyncInstrBase(builder) { - o("commit_group"); - } -}; - -struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase { - explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder) - : PTXCpAsyncInstrBase(builder) { - o("wait_group"); - } -}; - -struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase { - explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, - triton::CacheModifier modifier) - : PTXCpAsyncInstrBase(builder) { - o(triton::stringifyCacheModifier(modifier).str()); - o("shared"); - o("global"); - } + // Append a ".b[0-9]+" to the instruction + PTXInstr &b(int width); }; // Record the operands and context for "launching" a PtxInstr.
@@ -308,8 +263,10 @@ struct PTXInstrExecution { PTXInstrExecution() = default; explicit PTXInstrExecution(PTXInstrCommon *instr, - llvm::ArrayRef oprs) - : argsInOrder(oprs.begin(), oprs.end()), instr(instr) {} + llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} // Prefix a predicate to the instruction. PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { @@ -330,6 +287,22 @@ struct PTXInstrExecution { PTXInstrCommon *instr{}; Operand *pred{}; + bool onlyAttachMLIRArgs{}; +}; + +//// =============================== Some instruction wrappers +///=============================== +// We add the wrappers to make the usage more intuitive by avoiding mixing the +// PTX code with some trivial C++ code. + +struct PTXCpAsyncLoadInstr : PTXInstrBase { + explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, + triton::CacheModifier modifier) + : PTXInstrBase(builder, "cp.async") { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + } }; } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index 7b257fd3c..55ac2b73a 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -128,19 +128,26 @@ std::string PTXBuilder::dump() const { return strJoin(lines, "\n\t"); } -PTXInstrExecution &PTXInstrCommon::call(ArrayRef oprs) { +PTXInstrExecution &PTXInstrCommon::call(ArrayRef oprs, + bool onlyAttachMLIRArgs) { builder->executions.emplace_back( - std::make_unique(this, oprs)); + std::make_unique(this, oprs, onlyAttachMLIRArgs)); return *builder->executions.back(); } -PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef oprs) { - return call(oprs); +PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + return call(oprs, onlyAttachMLIRArgs); } std::string 
PTXInstrExecution::dump() const { std::string osStr; llvm::raw_string_ostream os(osStr); + + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) + return instrRepr; + if (pred) { if (!pred->repr) os << "@" << pred->dump() << " "; @@ -148,8 +155,6 @@ std::string PTXInstrExecution::dump() const { os << pred->repr(pred->idx) << " "; } - std::string instrRepr = strJoin(instr->instrParts, "."); - llvm::SmallVector argReprs; for (auto *arg : argsInOrder) { argReprs.push_back(arg->dump()); @@ -174,5 +179,27 @@ PTXInstrExecution::getArgList() const { return args; } +PTXInstr &PTXInstr::global() { + o("global"); + return *this; +} + +PTXInstr &PTXInstr::shared() { + o("shared"); + return *this; +} + +PTXInstr &PTXInstr::v(int vecWidth, bool predicate) { + if (vecWidth > 1) { + o("v" + std::to_string(vecWidth), predicate); + } + return *this; +} + +PTXInstr &PTXInstr::b(int width) { + o("b" + std::to_string(width)); + return *this; +} + } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index e9d0942df..30711eb5c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -388,9 +388,9 @@ static Value storeShared(ConversionPatternRewriter &rewriter, Location loc, const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); PTXBuilder builder; - auto &st = builder.create("st")->shared().b(bits); auto *ptrOpr = builder.newAddrOperand(ptr, "r"); auto *valOpr = builder.newOperand(val, c); + auto &st = builder.create<>("st")->shared().b(bits); st(ptrOpr, valOpr).predicate(pred, "b"); return builder.launch(rewriter, loc, void_ty(ctx)); } @@ -1005,7 +1005,6 @@ struct LoadOpConversion const bool hasL2EvictPolicy = false; PTXBuilder ptxBuilder; - auto &ld = *ptxBuilder.create("ld"); Value pred = mask ? 
maskElems[vecStart] : int_val(1, 1); @@ -1025,16 +1024,18 @@ struct LoadOpConversion ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); // Define the instruction opcode - ld.o("volatile", op.isVolatile()) - .global() - .o("ca", op.cache() == triton::CacheModifier::CA) - .o("cg", op.cache() == triton::CacheModifier::CG) - .o("L1::evict_first", - op.evict() == triton::EvictionPolicy::EVICT_FIRST) - .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", hasL2EvictPolicy) - .v(nWords) - .b(width); + auto &ld = ptxBuilder.create<>("ld") + ->o("volatile", op.isVolatile()) + .global() + .o("ca", op.cache() == triton::CacheModifier::CA) + .o("cg", op.cache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.evict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.evict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); PTXBuilder::Operand *evictOpr{}; @@ -1049,8 +1050,8 @@ struct LoadOpConversion if (other) { for (size_t ii = 0; ii < nWords; ++ii) { - PTXInstr &mov = *ptxBuilder.create<>("mov"); - mov.o("u" + std::to_string(width)); + PTXInstr &mov = + ptxBuilder.create<>("mov")->o("u" + std::to_string(width)); size_t size = width / valueElemNbits; @@ -1222,7 +1223,7 @@ struct StoreOpConversion ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); auto &ptxStoreInstr = - ptxBuilder.create("st")->global().v(nWords).b(width); + ptxBuilder.create<>("st")->global().v(nWords).b(width); ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b"); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); @@ -4802,7 +4803,7 @@ struct AsyncWaitOpConversion matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { PTXBuilder ptxBuilder; - auto &asyncWaitOp = *ptxBuilder.create(); + auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group"); auto num = 
op->getAttrOfType("num").getInt(); asyncWaitOp(ptxBuilder.newConstantOperand(num)); @@ -5025,7 +5026,7 @@ struct InsertSliceAsyncOpConversion } PTXBuilder ptxBuilder; - ptxBuilder.create()->operator()(); + ptxBuilder.create<>("cp.async.commit_group")->operator()(); ptxBuilder.launch(rewriter, loc, void_ty(getContext())); rewriter.replaceOp(op, llDst); return success(); @@ -5178,9 +5179,7 @@ struct AtomicRMWOpConversion auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r"); auto *valOpr = ptxBuilder.newOperand(rmvVal, "r"); - auto &atom = *ptxBuilder.create<>("atom"); - - atom.o("global").o("gpu"); + auto &atom = ptxBuilder.create<>("atom")->global().o("gpu"); auto rmwOp = stringifyRMWOp(atomicRmwAttr).str(); auto sBits = std::to_string(valueElemNbits); switch (atomicRmwAttr) { diff --git a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp index d4024b5c1..1c3f3fb27 100644 --- a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp @@ -76,7 +76,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) { auto &ld = builder - .create("ld") // + .create<>("ld") // ->o("volatile", isVolatile) .global() .o("ca", cache == CacheModifier::CA) @@ -121,5 +121,20 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) { EXPECT_EQ(values[1], v[2]); // $1 -> v[2] } +TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) { + PTXBuilder builder; + const char *ptxCode = + ".param .b64 param0;\n" // prepare param0 (format string) + "st.param.b64 [param0], %0;\n"; + + auto &ptxSnippet = *builder.create(ptxCode); + auto *opr = builder.newOperand(v[0], "r"); + ptxSnippet({opr}, true); + + EXPECT_EQ(builder.dump(), ptxCode); + ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]); + ASSERT_EQ(builder.getAllMLIRArgs().size(), 1); +} + } // namespace triton } // namespace mlir