From 1b513c9866d36e0a3ec6c404cb08985fec4dd78f Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 24 Aug 2022 06:51:13 +0800 Subject: [PATCH] [BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77) This PR does following things: Enhance the PTXFormat by Introducing PTXBuilder to enable multiple instructions in a single asm program override PTXInstr's operator() method to enable instr(opr0, opr1) style of setting operands for an instruction Refactor the PTX code used in LoadOpConversion with PTXFormat Authored-by: goostavz --- .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 231 +++++++++++------- .../TritonGPUToLLVM/PtxAsmFormat.cpp | 131 +++++++--- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 157 ++++++------ 3 files changed, 316 insertions(+), 203 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index a5eaff617..fee07b46f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -13,29 +13,52 @@ namespace mlir { namespace triton { using llvm::StringRef; -// TODO(Superjomn) Move to a global utility file? -std::string strJoin(llvm::ArrayRef strs, - llvm::StringRef delimiter); +class PTXInstr; +class PTXInstrCommon; -// A helper for building a single inline ASM instruction, the objective of -// PtxInstr is to give a thin encapsulation and make the ASM code for MLIR LLVM -// Dialect more clear. Currently, several factors are introduced to reduce the -// need for mixing string and C++ if-else code. +// PTXBuilder helps to manage a PTX asm program consists of one or multiple +// instructions. +// +// A helper for building a ASM program, the objective of PTXBuilder is to give a +// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. 
+// // Usage: -// To build: asm("add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k)); +// To build: asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p)); // -// PtxInstr mulr("mul"); -// mulr.o("lo").o("u32").addOperand(valueI, "=r") // %0 bind to valueI -// .addOperand(valueJ, "r") // %1 bind to valueJ -// .addOperand(valueK, "k"); // %2 bind to valueK +// PTXBuilder builder; +// auto& add = builder.create<>(); +// add.predicate(pVal).o("lo").o("u32"); // add any suffix +// // predicate here binds %0 to pVal, pVal is a mlir::Value // -// mulr.getConstrains() // get "=r,r,k" -// mulr.getAllMlirArgs() // get {valueI, valueJ, valueK} +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal +// add(iOpr, jOpr, kOpr); // set operands // -// TODO(Superjomn) Add multi-line ASM code support and register support later. -struct PtxInstr { - explicit PtxInstr(const std::string &name) { o(name); } - +// To get the asm code: +// builder.dump() +// +// To get all the mlir::Value used in the PTX code, +// +// builder.getAllMLIRArgs() // get {pVal, iVal, jVal, kVal} +// +// To get the string containing all the constraints with "," separated, +// builder.getConstrains() // get "=r,r,k" +// +// PTXBuilder can build a PTX asm with multiple instructions, sample code: +// +// PTXBuilder builder; +// auto& instr0 = builder.create<>(); +// auto& instr1 = builder.create<>(); +// auto& instr2 = builder.create<>(); +// +// NOTE, the instructions will be serialized in the order of creation. +// +// There are several derived instruction types for typical instructions, for +// example, the PtxIOInstr for ld and st instructions. 
+struct PTXBuilder { struct Operand { std::string constraint; Value value; @@ -48,16 +71,29 @@ struct PtxInstr { Operand(Value value, StringRef constraint) : value(value), constraint(constraint) {} - bool isList() const { return !value; } + bool isList() const { return !value && constraint.empty(); } Operand *listAppend(Operand *arg) { list.push_back(arg); return this; } + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + std::string dump() const; }; + template INSTR *create(const std::string &name) { + instrs.emplace_back(std::make_unique(this, name)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + // Create a new operand. It will not add to operand list. // @value: the MLIR value bind to this operand. // @constraint: ASM operand constraint, .e.g. "=r" @@ -66,7 +102,65 @@ struct PtxInstr { Operand *newOperand(mlir::Value value, StringRef constraint, std::function formater = nullptr); - // Append the operand to the intruction's operand list. + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + Operand *newOperand(StringRef constraint); + + // Create a constant integer operand. + Operand *newConstantOperand(int v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstrains() const; + + std::string dump() const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + friend class PTXInstr; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> instrs; + int oprCounter{}; +}; + +// PTX instruction common interface. 
+// Put the generic logic for all the instructions here. +struct PTXInstrCommon { + explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {} + + using Operand = PTXBuilder::Operand; + + llvm::SmallVector getArgList() const; + + std::string dump() const; + + // clang-format off + void operator()(Operand* a) { operator()({a}); } + void operator()(Operand* a, Operand* b) { operator()({a, b}); } + void operator()(Operand* a, Operand* b, Operand* c) { operator()({a, b, c}); } + void operator()(Operand* a, Operand* b, Operand* c, Operand* d) { operator()({a, b, c, d}); } + void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { operator()({a, b, c, d, e}); } + void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { operator()({a, b, c, d, e, f}); } + void operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { operator()({a, b, c, d, e, f, g}); } + // clang-format on + + // Set operands of this instruction. + void operator()(llvm::ArrayRef oprs); + +protected: + // Append the operand to the instruction's operand list. Operand *addOperand(Operand *opr) { assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) == argsInOrder.end()); @@ -74,78 +168,47 @@ struct PtxInstr { return opr; } - // Create and add an operand to the intruction's operand list. - Operand *addOperand(mlir::Value value, StringRef constraint) { - auto *opr = newOperand(value, constraint); - return addOperand(opr); - } + PTXBuilder *builder{}; + Operand *pred{}; + llvm::SmallVector instrParts; + llvm::SmallVector argsInOrder; +}; - // Prefix a predicate to the instruction. 
- PtxInstr &predicate(mlir::Value value, StringRef constraint) { - pred = newOperand(value, constraint); - return *this; +template struct PTXInstrBase : public PTXInstrCommon { + using Operand = PTXBuilder::Operand; + + explicit PTXInstrBase(PTXBuilder *builder, const std::string &name) + : PTXInstrCommon(builder) { + o(name); } // Append a suffix to the instruction. - // e.g. PtxInstr("add").o("s32") get a add.s32. + // e.g. PTXInstr("add").o("s32") get a add.s32. // A predicate is used to tell whether to apply the suffix, so that no if-else - // code needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32", !isS32);` will + // code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will // get a `add.s32` if isS32 is true. - PtxInstr &o(const std::string &suffix, bool predicate = true) { + ConcreteT &o(const std::string &suffix, bool predicate = true) { if (predicate) instrParts.push_back(suffix); - return *this; + return *static_cast(this); } - PtxInstr &addListOperation(llvm::ArrayRef list) { - auto *opr = newList(); - for (auto *v : list) - opr->listAppend(v); - addOperand(opr); - return *this; + // Prefix a predicate to the instruction. + ConcreteT &predicate(mlir::Value value, StringRef constraint) { + pred = builder->newOperand(value, constraint); + return *static_cast(this); } - // Create a list of operands. - Operand *newList() { - argArchive.emplace_back(std::make_unique()); - return argArchive.back().get(); + // Prefix a !predicate to the instruction. 
+ ConcreteT &predicateNot(mlir::Value value, StringRef constraint) { + pred = builder->newOperand(value, constraint); + pred->repr = [](int idx) { return llvm::formatv("@!${0} ", idx); }; + return *static_cast(this); } +}; - std::string dump() const; - - llvm::SmallVector getArgList() const; - llvm::SmallVector getAllArgs() const { - llvm::SmallVector res; - for (auto &x : argArchive) - if (!x->isList()) - res.push_back(x.get()); - return res; - } - - std::string getConstrains() const { - auto args = getAllArgs(); - llvm::SmallVector argReprs; - for (auto arg : args) - argReprs.push_back(arg->constraint); - return strJoin(argReprs, ","); - } - - llvm::SmallVector getAllMlirArgs() const { - llvm::SmallVector res; - for (auto &arg : argArchive) { - if (!arg->isList()) - res.push_back(arg->value); - } - return res; - } - -protected: - Operand *pred{}; - int oprCounter{}; - llvm::SmallVector instrParts; - llvm::SmallVector, 6> argArchive; - llvm::SmallVector argsInOrder; - std::string argStr; +struct PTXInstr : public PTXInstrBase { + using PTXInstrBase::PTXInstrBase; }; // A helper for PTX ld/st instruction. 
@@ -153,8 +216,8 @@ protected: // PtxIOInstr store("st"); // store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1 // store.addAddr(addrValue, "l", off); -struct PtxIOInstr : public PtxInstr { - PtxIOInstr(const std::string &name) : PtxInstr(name) {} +struct PtxIOInstr : public PTXInstrBase { + using PTXInstrBase::PTXInstrBase; // Add ".global" suffix to instruction PtxIOInstr &global(bool predicate = true) { @@ -175,14 +238,6 @@ struct PtxIOInstr : public PtxInstr { o(llvm::formatv("b{0}", width)); return *this; } - - PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) { - auto *operand = newAddrOperand(addr, constraint, off); - addOperand(operand); - return *this; - } - - Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); }; } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index ee126df18..00a9c0d03 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -1,9 +1,11 @@ #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" #include "llvm/Support/raw_ostream.h" +#include // unify to llvm::raw_string_ostream ? namespace mlir { namespace triton { +// TODO(Superjomn) Move to a global utility file? 
std::string strJoin(llvm::ArrayRef strs, llvm::StringRef delimiter) { std::string osStr; @@ -16,11 +18,101 @@ std::string strJoin(llvm::ArrayRef strs, return osStr; } -std::string PtxInstr::dump() const { +PTXInstr::Operand * +PTXBuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formater) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formater; + opr->idx = oprCounter++; + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) { + // Constraint should be something like "=r" + assert(!constraint.empty() && constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string PTXBuilder::getConstrains() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector PTXBuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector PTXBuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +std::string PTXInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return llvm::formatv("${0}", idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + 
+PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + return llvm::formatv("[ ${0} + {1} ]", idx, off); + }; + + return opr; +} + +std::string PTXBuilder::dump() const { + llvm::SmallVector lines; + for (auto &instr : instrs) { + lines.push_back(instr->dump()); + } + + return strJoin(lines, "\n\t"); +} + +std::string PTXInstrCommon::dump() const { std::string osStr; llvm::raw_string_ostream os(osStr); if (pred) - os << "@" << pred->dump() << " "; + if (!pred->repr) + os << "@" << pred->dump() << " "; + else + os << pred->repr(pred->idx); std::string instrRepr = strJoin(instrParts, "."); @@ -36,7 +128,7 @@ std::string PtxInstr::dump() const { return osStr; } -llvm::SmallVector PtxInstr::getArgList() const { +SmallVector PTXInstrCommon::getArgList() const { SmallVector args; for (auto *arg : argsInOrder) { if (arg->isList()) @@ -47,35 +139,10 @@ llvm::SmallVector PtxInstr::getArgList() const { return args; } -PtxInstr::Operand * -PtxInstr::newOperand(mlir::Value value, StringRef constraint, - std::function formater) { - argArchive.emplace_back(std::make_unique(value, constraint)); - auto *opr = argArchive.back().get(); - opr->repr = formater; - opr->idx = oprCounter++; - return opr; -} - -std::string PtxInstr::Operand::dump() const { - if (repr) - return repr(idx); - if (!isList()) - return llvm::formatv("${0}", idx); - llvm::SmallVector oprs; - for (auto *opr : list) - oprs.push_back(opr->dump()); - return "{ " + strJoin(oprs, ", ") + " }"; -} - -PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr, - StringRef constraint, int off) { - auto *opr = newOperand(addr, constraint); - opr->repr = [off](int idx) -> std::string { - return llvm::formatv("[ ${0} + {1} ]", idx, off); - }; - - return opr; +void PTXInstrCommon::operator()(ArrayRef oprs) { + for (auto *opr : oprs) { + addOperand(opr); + } } } // namespace triton } // 
namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 96b24f1ef..48e1a7818 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -708,18 +708,20 @@ struct StoreOpConversion // TODO(Superjomn) deal with cache policy. const bool hasL2EvictPolicy = false; - PtxIOInstr asmStoreInstr("st"); - asmStoreInstr.predicate(maskElems[vecIdx], "b"); - asmStoreInstr.global().b(width).v(nWords); + PTXBuilder ptxBuilder; + auto &ptxStoreInstr = *ptxBuilder.create("st"); + ptxStoreInstr.predicate(maskElems[vecIdx], "b") + .global() + .b(width) + .v(nWords); llvm::SmallVector asmArgs; Type valArgTy = IntegerType::get(ctx, width); auto wordTy = VectorType::get(wordNElems, valueElemTy); - auto *asmAddr = - asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off); - auto *asmArgList = asmStoreInstr.newList(); + auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off); + auto *asmArgList = ptxBuilder.newListOperand(); for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { // llWord is a width-len composition Value llWord = rewriter.create(loc, wordTy); @@ -740,11 +742,10 @@ struct StoreOpConversion llWord = rewriter.create(loc, valArgTy, llWord); std::string constraint = (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); - asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint)); + asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); } - asmStoreInstr.addOperand(asmAddr); - asmStoreInstr.addOperand(asmArgList); + ptxStoreInstr(asmAddr, asmArgList); llvm::SmallVector argTys({mask.getType(), ptr.getType()}); for (int i = 0; i < nWords; i++) @@ -753,9 +754,9 @@ struct StoreOpConversion auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {}); auto inlineAsm = rewriter.create( - loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands - asmStoreInstr.dump(), // asm_string - asmStoreInstr.getConstrains(), // constraints + loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands + ptxBuilder.dump(), // asm_string + ptxBuilder.getConstrains(), // constraints // TODO(Superjomn) determine the side effect. true, // has_side_effects false, // is_align_stack @@ -1008,43 +1009,54 @@ struct LoadOpConversion // --- // create inline asm string // --- - // TODO: (Superjomn) refactor with AsmInstr abstraction - std::ostringstream asmOss; - asmOss << "@$" << n_words; // predicate - asmOss << " ld"; - if (op.isVolatile()) { - asmOss << ".volatile"; + + const std::string writeConstrait = + (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + const std::string readConstrait = + (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); + + PTXBuilder ptxBuilder; + PtxIOInstr &ld = *ptxBuilder.create("ld"); + + // Define the instruction opcode + ld.predicate(pred, "b") + .o("volatile", op.isVolatile()) + .global() + .o("ca", op.cache() == triton::CacheModifier::CA) + .o("cg", op.cache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.evict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) + .o("L2::cache_hint", has_l2_evict_policy) + .v(n_words) + .b(width); + + // prepare asm operands + auto *dstsOpr = ptxBuilder.newListOperand(); + for (int i = 0; i < n_words; i++) { + auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations + dstsOpr->listAppend(opr); } - asmOss << ".global"; - if (op.cache() == triton::CacheModifier::CA) - asmOss << ".ca"; - if (op.cache() == triton::CacheModifier::CG) - asmOss << ".cg"; - if (op.evict() == triton::EvictionPolicy::EVICT_FIRST) - asmOss << ".L1::evict_first"; - if (op.evict() == triton::EvictionPolicy::EVICT_LAST) - asmOss << ".L1::evict_last"; - if (has_l2_evict_policy) - asmOss << ".L2::cache_hint"; - if (n_words > 1) - asmOss << ".v" << n_words; // vector width - asmOss << ".b" << width; // word size - asmOss << " {"; - for (int i = 0; i < n_words; i++) { // return values - if (i > 0) - asmOss << ","; - asmOss << "$" << i; - } - asmOss << "}"; - asmOss << ", [ $" << n_words + 1; // load - asmOss << " + " << in_off << "]"; // constant offset - if (has_l2_evict_policy) - asmOss << ", $" << n_words + 2; - asmOss << ";"; + auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off); + + PTXBuilder::Operand *evictOpr{}; + // Here lack a mlir::Value to bind to this operation, so disabled. 
+ // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr); + else + ld(dstsOpr, addrOpr, evictOpr); + SmallVector others; - if (other != nullptr) { + if (other) { for (size_t ii = 0; ii < n_words; ii++) { + PTXInstr &mov = *ptxBuilder.create<>("mov"); + mov.predicateNot(pred, "b").o("u", width); + size_t size = width / nbits; + auto vecTy = LLVM::getFixedVectorType(elemTy, size); Value v = rewriter.create(loc, vecTy); for (size_t s = 0; s < size; s++) { @@ -1056,20 +1068,19 @@ struct LoadOpConversion } v = rewriter.create( loc, IntegerType::get(getContext(), width), v); - asmOss << "\n "; - asmOss << "@!$" << n_words << " mov.u" << width; - asmOss << " $" << ii << ", "; - std::ios_base::fmtflags flags(asmOss.flags()); - if (otherIsSplatConstInt) - asmOss << "0x" << std::hex << splatVal; - else { - asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii; + + PTXInstr::Operand *opr{}; + if (otherIsSplatConstInt) { + opr = ptxBuilder.newConstantOperand(splatVal); + } else { + opr = ptxBuilder.newOperand(v, readConstrait); others.push_back(v); } - asmOss.flags(flags); - asmOss << ";"; + + mov(dstsOpr->listGet(ii), opr); } } + // --- // create inline ASM signature // --- @@ -1077,39 +1088,18 @@ struct LoadOpConversion Type retTy = retTys.size() > 1 ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) : retTys[0]; - // --- - // create inline ASM constraints - // --- - std::string asmCstrt; - for (int ii = 0; ii < n_words; ii++) { - if (ii > 0) - asmCstrt += ","; - asmCstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); - } - asmCstrt += ",b,l"; - for (size_t ii = 0; ii < others.size(); ii++) { - asmCstrt += ","; - asmCstrt += (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); - } - if (has_l2_evict_policy) { - asmCstrt += ",l"; - } - // --- - // finally call inline ASM - // --- - SmallVector args = {pred, ptr}; - for (Value v : others) { - args.push_back(v); - } + // TODO: if (has_l2_evict_policy) auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); auto inlineAsmOp = rewriter.create( - loc, retTy, /*operands=*/args, /*asm_string=*/asmOss.str(), - /*constraints=*/asmCstrt, /*has_side_effects=*/true, + loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(), + /*asm_string=*/ptxBuilder.dump(), + /*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true, /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); Value ret = inlineAsmOp.getResult(0); + // --- // extract and store return values // --- @@ -1135,7 +1125,7 @@ struct LoadOpConversion loc, elemTy, rets[ii / tmp], vecIdx); loadedVals.push_back(loaded); } - } + } // end vec Type llvmResultStructTy = getTypeConverter()->convertType(resultTy); Value resultStruct = getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); @@ -1314,7 +1304,8 @@ public: auto axisAnalysis = runAxisAnalysis(mod); // We set a higher benefit here to ensure triton's patterns runs before - // arith patterns for some encoding not supported by the community patterns. + // arith patterns for some encoding not supported by the community + // patterns. populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, *axisAnalysis, 10 /*benefit*/);