diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 000000000..5f33ced4b --- /dev/null +++ b/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,39 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H_ +#define TRITON_CONVERSION_MLIR_TYPES_H_ + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +Type i32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Signed); +} +Type i8Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 8, IntegerType::Signed); +} +Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Signless); +} +Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); } + +// Float types +Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); } +Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); } +Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); } + +static bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128(); +} + +static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H_ diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h new file mode 100644 index 000000000..a5eaff617 --- /dev/null +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -0,0 +1,191 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include +#include + +namespace mlir { +namespace triton { +using llvm::StringRef; + +// TODO(Superjomn) Move to a global utility file? +std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter); + +// A helper for building a single inline ASM instruction, the objective of +// PtxInstr is to give a thin encapsulation and make the ASM code for MLIR LLVM +// Dialect more clear. Currently, several factors are introduced to reduce the +// need for mixing string and C++ if-else code. +// Usage: +// To build: asm("add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k)); +// +// PtxInstr mulr("mul"); +// mulr.o("lo").o("u32").addOperand(valueI, "=r") // %0 bind to valueI +// .addOperand(valueJ, "r") // %1 bind to valueJ +// .addOperand(valueK, "k"); // %2 bind to valueK +// +// mulr.getConstrains() // get "=r,r,k" +// mulr.getAllMlirArgs() // get {valueI, valueJ, valueK} +// +// TODO(Superjomn) Add multi-line ASM code support and register support later. +struct PtxInstr { + explicit PtxInstr(const std::string &name) { o(name); } + + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(Value value, StringRef constraint) + : value(value), constraint(constraint) {} + + bool isList() const { return !value; } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + std::string dump() const; + }; + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formater: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). + Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formater = nullptr); + + // Append the operand to the intruction's operand list. + Operand *addOperand(Operand *opr) { + assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) == + argsInOrder.end()); + argsInOrder.push_back(opr); + return opr; + } + + // Create and add an operand to the intruction's operand list. + Operand *addOperand(mlir::Value value, StringRef constraint) { + auto *opr = newOperand(value, constraint); + return addOperand(opr); + } + + // Prefix a predicate to the instruction. + PtxInstr &predicate(mlir::Value value, StringRef constraint) { + pred = newOperand(value, constraint); + return *this; + } + + // Append a suffix to the instruction. + // e.g. PtxInstr("add").o("s32") get a add.s32. + // A predicate is used to tell whether to apply the suffix, so that no if-else + // code needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32", !isS32);` will + // get a `add.s32` if isS32 is true. + PtxInstr &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *this; + } + + PtxInstr &addListOperation(llvm::ArrayRef list) { + auto *opr = newList(); + for (auto *v : list) + opr->listAppend(v); + addOperand(opr); + return *this; + } + + // Create a list of operands. + Operand *newList() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + std::string dump() const; + + llvm::SmallVector getArgList() const; + llvm::SmallVector getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; + } + + std::string getConstrains() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); + } + + llvm::SmallVector getAllMlirArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList()) + res.push_back(arg->value); + } + return res; + } + +protected: + Operand *pred{}; + int oprCounter{}; + llvm::SmallVector instrParts; + llvm::SmallVector, 6> argArchive; + llvm::SmallVector argsInOrder; + std::string argStr; +}; + +// A helper for PTX ld/st instruction. +// Usage: +// PtxIOInstr store("st"); +// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1 +// store.addAddr(addrValue, "l", off); +struct PtxIOInstr : public PtxInstr { + PtxIOInstr(const std::string &name) : PtxInstr(name) {} + + // Add ".global" suffix to instruction + PtxIOInstr &global(bool predicate = true) { + o("global", predicate); + return *this; + } + + // Add ".v" suffix to instruction + PtxIOInstr &v(int vecWidth, bool predicate = true) { + if (vecWidth > 1) { + o(llvm::formatv("v{0}", vecWidth), predicate); + } + return *this; + } + + // Add ".b" suffix to instruction + PtxIOInstr &b(int width) { + o(llvm::formatv("b{0}", width)); + return *this; + } + + PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) { + auto *operand = newAddrOperand(addr, constraint, off); + addOperand(operand); + return *this; + } + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index ec464b99c..ad971c9e6 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonGPUToLLVM.cpp + PtxAsmFormat.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp new file mode 100644 index 000000000..579b91c39 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -0,0 +1,81 @@ +#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace triton { + +std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + std::string osStr; + llvm::raw_string_ostream os(osStr); + for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++) + os << strs[i] << delimiter; + if (!strs.empty()) + os << strs.back(); + os.flush(); + return osStr; +} + +std::string PtxInstr::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + if (pred) + os << "@" << pred->dump() << " "; + + std::string instrRepr = strJoin(instrParts, "."); + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + os << instrRepr << " " << argsRepr << ";"; + os.flush(); + return osStr; +} + +llvm::SmallVector PtxInstr::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +PtxInstr::Operand * +PtxInstr::newOperand(mlir::Value value, StringRef constraint, + std::function formater) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formater; + opr->idx = oprCounter++; + return opr; +} + +std::string PtxInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return llvm::formatv("%{0}", idx); + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + +PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + return llvm::formatv("[ %{0} + {1} ]", idx, off); + }; + + return opr; +} +} // namespace triton +} // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 9de3f8a1e..74dcf8220 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,5 +1,4 @@ #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" - #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" @@ -10,10 +9,17 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include #include +#include using namespace mlir; using namespace mlir::triton; @@ -31,8 +37,15 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } namespace { +namespace type = mlir::triton::type; + class TritonGPUToLLVMTypeConverter; +// TODO(Superjomn) Move to somewhere general utilities locates. +template size_t product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} + // The following code are borrowed from mlir project including the following // functions or classes: // - filterFuncAttributes @@ -81,7 +94,6 @@ protected: TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); - assert(llvmType); if (!llvmType) return nullptr; @@ -124,6 +136,8 @@ protected: } linkage = attr.getLinkage(); } + + auto oldArgs = funcOp.getArguments(); auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal*/ false, attributes); @@ -134,6 +148,25 @@ protected: &result))) return nullptr; + // Convert argument + llvm::DenseMap argMap; + for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) { + Value oldArg = oldArgs[i]; + Value newArg = newFuncOp.getArgument(i); + argMap.try_emplace(oldArg, newArg); + } + + newFuncOp.getBody().walk([&](Operation *op) { + // Convert the function argument types, e.g, from !tt.ptr to + // ptr + for (int i = 0; i < op->getNumOperands(); i++) { + auto arg = op->getOperand(i); + auto it = argMap.find(arg); + if (it != argMap.end()) + op->setOperand(i, it->second); + } + }); + return newFuncOp; } }; @@ -143,8 +176,9 @@ protected: /// information. static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { - FuncOpConversion(LLVMTypeConverter &converter, int numWarps) - : FuncOpConversionBase(converter), NumWarps(numWarps) {} + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : FuncOpConversionBase(converter, benefit), NumWarps(numWarps) {} LogicalResult matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, @@ -154,11 +188,11 @@ struct FuncOpConversion : public FuncOpConversionBase { return failure(); auto ctx = funcOp->getContext(); - auto i32 = IntegerType::get(ctx, 32); // Set an attribute for maxntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. - newFuncOp->setAttr(NVVMMetadataField::MaxNTid, - rewriter.getIntegerAttr(i32, 32 * NumWarps)); + newFuncOp->setAttr( + NVVMMetadataField::MaxNTid, + rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps)); rewriter.eraseOp(funcOp); return success(); @@ -190,22 +224,47 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { } }; -// Extract numWarps information from TritonGPU module, return 0 if failed. -// This is a naive implementation, it assumes that all the blocked layout should -// have the same numWarps setting in a module, it just find a blocked layout -// encoding and return the warpsPerCTA field. -int extractNumWarps(mlir::ModuleOp module) { - int numWarps{}; - if (module->hasAttr(AttrNumWarpsName)) - numWarps = module->getAttr(AttrNumWarpsName) - .dyn_cast() - .getValue() - .getZExtValue(); - else - llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); +static int64_t getLinearIndex(std::vector multidim_index, + ArrayRef shape) { + assert(multidim_index.size() == shape.size()); + // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} + int64_t rank = shape.size(); + int64_t acc_mul = 1; + for (int64_t i = 1; i < rank; ++i) { + acc_mul *= shape[i]; + } + int64_t linear_index = 0; + for (int64_t i = 0; i < rank; ++i) { + linear_index += multidim_index[i] * acc_mul; + if (i != (rank - 1)) { + acc_mul = acc_mul / shape[i + 1]; + } + } + return linear_index; +} - return numWarps; +static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout, + ArrayRef shape) { + return product(shape) / (product(layout.getThreadsPerWarp()) * + product(layout.getWarpsPerCTA())); +} + +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + +Value getStructFromElements(Location loc, ValueRange resultVals, + ConversionPatternRewriter &rewriter, + Type structType, Type elemPtrPtrType) { + Value llvmStruct = rewriter.create(loc, structType); + for (auto v : llvm::enumerate(resultVals)) { + llvmStruct = rewriter.create( + loc, structType, llvmStruct, v.value(), + rewriter.getI64ArrayAttr(v.index())); + } + return llvmStruct; } template @@ -247,24 +306,6 @@ static T getLinearIndex(ArrayRef multidim_index, ArrayRef shape) { return linear_index; } -static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout, - ArrayRef shape) { - unsigned elems = 1; - size_t rank = shape.size(); - assert(rank == layout.getThreadsPerWarp().size()); - for (size_t d = 0; d < rank; ++d) { - elems *= - shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]); - } - return elems; -} - -static Value createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(resultType, value)); -} - template class ConvertTritonGPUOpToLLVMPattern : public ConvertOpToLLVMPattern { @@ -443,6 +484,358 @@ public: } }; +// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a +// LLVM::StructType value. +// +// @elemType: the element type in operand. +// @resType: the return type of the Splat-like op. +// @constVal: a LLVM::ConstantOp or other scalar value. +Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc) { + + auto tensorTy = resType.cast(); + auto layout = tensorTy.getEncoding().cast(); + auto srcType = typeConverter->convertType(elemType); + auto llSrc = rewriter.create(loc, srcType, constVal); + + auto numElems = layout.getSizePerThread(); + size_t totalElems = + std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1, + std::multiplies<>{}); + size_t numThreads = + product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp()); + // TODO(Superjomn) add numElemsPerThread to the layout encodings. + size_t numElemsPerThread = totalElems / numThreads; + + llvm::SmallVector elems(numElemsPerThread, llSrc); + llvm::SmallVector elemTypes(elems.size(), srcType); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); + + auto llElemPtrPtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType)); + auto llStruct = + getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy); + return llStruct; +} + +struct SplatOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto src = op->getOperand(0); + + LLVM::ConstantOp arithConstantOp; + if (src.getDefiningOp() && + (arithConstantOp = + llvm::dyn_cast(src.getDefiningOp()))) { + Value constant; + auto values = arithConstantOp.getValue().dyn_cast(); + + assert(values.size() == 1); + Attribute val; + if (type::isInt(src.getType())) { + val = values.getValues()[0]; + } else if (type::isFloat(src.getType())) { + val = values.getValues()[0]; + } else { + llvm::errs() << "Constant op type not supported"; + return failure(); + } + + src = rewriter.create(loc, val.getType(), val); + } + + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + getTypeConverter(), rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; + +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the +// logic is the same as triton::SplatOp, so the underlying implementation is +// reused. +struct ArithConstantSplatOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!value.dyn_cast()) + return failure(); + + auto loc = op->getLoc(); + + LLVM::ConstantOp arithConstantOp; + auto values = op.getValue().dyn_cast(); + auto elemType = values.getElementType(); + + Attribute val; + if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + + auto constOp = rewriter.create(loc, elemType, val); + auto llStruct = convertSplatLikeOp(elemType, op.getType(), constOp, + getTypeConverter(), rewriter, loc); + rewriter.replaceOp(op, llStruct); + + return success(); + } +}; + +struct StoreOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; + + StoreOpConversion(LLVMTypeConverter &converter, + AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(converter, benefit), + AxisAnalysisPass(axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value ptr = op.ptr(); + Value mask = op.mask(); + Value value = op.value(); + + Value llPtr = adaptor.ptr(); // should be LLVM ops + Value llMask = adaptor.mask(); + Value llValue = adaptor.value(); + + Type valueElemTy = getTypeConverter()->convertType( + value.getType().cast().getElementType()); + + MLIRContext *ctx = rewriter.getContext(); + auto loc = op->getLoc(); + + auto getLLVMElems = [&](Value value, Value llValue, + const TritonGPUBlockedEncodingAttr &layout) + -> SmallVector { + auto ty = value.getType().cast(); + auto shape = ty.getShape(); + // Here, we assume that all inputs should have a blockedLayout + + unsigned valueElems = getElemsPerThread(layout, shape); + + auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType()); + auto llvmElemPtrPtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy)); + + auto valueVals = + getElementsFromStruct(loc, llValue, valueElems, rewriter); + return valueVals; + }; + + auto getLayout = + [&](Value val) -> std::tuple { + auto ty = val.getType().cast(); + auto shape = ty.getShape(); + // Here, we assume that all inputs should have a blockedLayout + auto layout = ty.getEncoding().dyn_cast(); + + unsigned valueElems = getElemsPerThread(layout, shape); + + return std::make_tuple(layout, valueElems); + }; + + auto [ptrLayout, ptrNumElems] = getLayout(ptr); + auto [maskLayout, maskNumElems] = getLayout(mask); + auto [valueLayout, valueNumElems] = getLayout(value); + + auto valueElems = getLLVMElems(value, llValue, valueLayout); + auto maskElems = getLLVMElems(mask, llMask, maskLayout); + assert(valueElems.size() == maskElems.size()); + + auto getAlign = + [this](Value val, + const TritonGPUBlockedEncodingAttr &layout) -> unsigned { + auto axisInfo = getAxisInfo(val); + assert(axisInfo.hasValue()); + + auto order = layout.getOrder(); + + unsigned maxMultiple = axisInfo->getDivisibility(order[0]); + unsigned maxContig = axisInfo->getContiguity(order[0]); + unsigned alignment = std::min(maxMultiple, maxContig); + return alignment; + }; + + // get align + auto getVec = [this, &getAlign]( + Value val, + const TritonGPUBlockedEncodingAttr &layout) -> unsigned { + auto axisInfo = getAxisInfo(val); + auto contig = axisInfo->getContiguity(); + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = layout.getOrder(); + unsigned align = getAlign(val, layout); + + assert(!order.empty()); + // Is this right? + unsigned contigPerThread = layout.getSizePerThread()[order[0]]; + unsigned vec = std::min(align, contigPerThread); + + // TODO(Superjomn) Consider the is_mma_first_row in the legacy code + bool isMMAFirstRow = false; + + if (isMMAFirstRow) + vec = std::min(2, align); + + return vec; + }; + + // Determine the vectorization size + size_t vec = getVec(ptr, ptrLayout); + + const size_t dtsize = value.getType() + .cast() + .getElementType() + .getIntOrFloatBitWidth() / + 8; + const size_t valueElemNbits = dtsize * 8; + + const int numVecs = ptrNumElems / vec; + for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) { + + size_t in_off{}; + auto ptrProducer = llPtr.getDefiningOp(); + auto in_gep = llvm::dyn_cast(ptrProducer); + + if (in_gep) { + auto indices = in_gep.getIndices(); + auto cst = dyn_cast(indices.front().getDefiningOp()); + in_off = + cst ? cst.getValue().dyn_cast().getInt() * dtsize : 0; + ptr = cst ? in_gep.getBase() : in_gep; + } + + // pack sub-words (< 32/64bits) into words + // each load has width min(nbits*vec, 32/64) + // and there are (nbits * vec)/width of them + const int maxWordWidth = std::max(32, valueElemNbits); + const int totalWidth = valueElemNbits * vec; + const int width = std::min(totalWidth, maxWordWidth); + const int nWords = std::max(1, totalWidth / width); + const int wordNElems = width / valueElemNbits; + const int vecNElems = totalWidth / valueElemNbits; + + assert(wordNElems * nWords * numVecs == valueElems.size()); + + // TODO(Superjomn) Add cache policy to store. + // TODO(Superjomn) deal with cache policy. + const bool hasL2EvictPolicy = false; + + PtxIOInstr asmStoreInstr("st"); + asmStoreInstr.predicate(llMask, "b"); + asmStoreInstr.global().v(width).b(nWords); + + llvm::SmallVector asmArgs; + + Type valArgTy = IntegerType::get(ctx, width); + auto wordTy = VectorType::get(wordNElems, valueElemTy); + + auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off); + auto *asmArgList = asmStoreInstr.newList(); + for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { + // llWord is a width-len composition + Value llWord = rewriter.create(loc, wordTy); + // Insert each value element to the composition + for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) { + Value elem = + valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx]; + if (elem.getType().isInteger(1)) + elem = rewriter.create(loc, type::i8Ty(ctx), elem); + elem = rewriter.create(loc, valueElemTy, elem); + + llWord = rewriter.create( + loc, wordTy, llWord, elem, + rewriter.create( + loc, type::u32Ty(ctx), + IntegerAttr::get(type::u32Ty(ctx), elemIdx))); + } + llWord = rewriter.create(loc, valArgTy, llWord); + std::string constraint = + (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint)); + } + + asmStoreInstr.addOperand(asmAddr); + asmStoreInstr.addOperand(asmArgList); + + llvm::SmallVector argTys({mask.getType(), ptr.getType()}); + for (int i = 0; i < nWords; i++) + argTys.push_back(valArgTy); + + auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {}); + + auto inlineAsm = rewriter.create( + loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands + asmStoreInstr.dump(), // asm_string + asmStoreInstr.getConstrains(), // constraints + // TODO(Superjomn) determine the side effect. + true, // has_side_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, {}) // operand_attrs + ); + + rewriter.replaceOp(op, inlineAsm.getRes()); + } + return success(); + } + + llvm::Optional getAxisInfo(Value val) const { + if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) { + return it->getValue(); + } + + return llvm::Optional{}; + } + +private: + AxisInfoAnalysis &AxisAnalysisPass; +}; + +// Extract numWarps information from TritonGPU module, return 0 if failed. +// This is a naive implementation, it assumes that all the blocked layout should +// have the same numWarps setting in a module, it just find a blocked layout +// encoding and return the warpsPerCTA field. +int extractNumWarps(mlir::ModuleOp module) { + int numWarps{}; + if (module->hasAttr(AttrNumWarpsName)) + numWarps = module->getAttr(AttrNumWarpsName) + .dyn_cast() + .getValue() + .getZExtValue(); + else + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + + return numWarps; +} + struct BroadcastOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -647,8 +1040,7 @@ struct LoadOpConversion for (size_t i = 0; i < numElems; i += vecWidth) { Value ptr = ptrVals[i]; // TODO: Handle the optimization if ptr is from GEP and the idx is - // constant - // This should be a canonicalization pattern in LLVM Dialect + // constant. This should be a canonicalization pattern in LLVM Dialect unsigned in_off = 0; Value pred = maskVals[i]; @@ -826,13 +1218,18 @@ public: }; void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps) { - patterns.add(typeConverter); - patterns.add(typeConverter, numWarps); - patterns.add(typeConverter); - patterns.add(typeConverter); - patterns.add(typeConverter); - patterns.add(typeConverter); + RewritePatternSet &patterns, int numWarps, + AxisInfoAnalysis &analysis, + PatternBenefit benefit = 1) { + patterns.add(typeConverter, numWarps, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, analysis, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } class ConvertTritonGPUToLLVM @@ -851,20 +1248,33 @@ public: TritonLLVMConversionTarget target(*context, typeConverter); RewritePatternSet patterns(context); - // TODO: (goostavz) Temporarily disable this, since the lowering of - // arithmetic ops in tensor format is not complete yet. - // Add arith's patterns to help convert scalar expression to LLVM. - // mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, - // patterns); int numWarps = extractNumWarps(mod); - populateTritonToLLVMPatterns(typeConverter, patterns, numWarps); + auto axisAnalysis = runAxisAnalysis(mod); + + // We set a higher benefit here to ensure triton's patterns runs before + // arith patterns for some encoding not supported by the community patterns. + populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, + *axisAnalysis, 10 /*benefit*/); + + // Add arith's patterns to help convert scalar expression to LLVM. + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); } + +protected: + std::unique_ptr runAxisAnalysis(ModuleOp module) { + auto axisAnalysisPass = + std::make_unique(module->getContext()); + axisAnalysisPass->run(module); + return axisAnalysisPass; + } }; } // namespace diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 77d74aa8a..132d00c44 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -102,6 +102,7 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { auto constOperand = src().getDefiningOp(); if (!constOperand) return {}; + auto shapedType = getType().cast(); auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()}); return ret; diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir new file mode 100644 index 000000000..0d9aea81d --- /dev/null +++ b/test/Conversion/triton_to_llvm.mlir @@ -0,0 +1,36 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s + +func @test_splat(%ptr: !tt.ptr) { + // Here, 128 elements, 64(2*32) threads, so each need to process 2 elements + // + // CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr to !llvm.ptr + // CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr, ptr)> + // CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr, ptr)> + // CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr, ptr)> + %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> + %a = arith.constant 1.0 : f32 + %true = arith.constant 1 : i1 + %b = tt.splat %a : (f32) -> tensor<128xf32> + + // Here, each thread process only 1 element + // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)> + %mask = tt.splat %true : (i1) -> tensor<64xi1> + + return +} + +func @test_store_splat(%ptr: !tt.ptr) { + %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> + %a = arith.constant 1.0 : f32 + %true = arith.constant 1 : i1 + + %vs = tt.splat %a : (f32) -> tensor<128xf32> + %mask = tt.splat %true : (i1) -> tensor<128xi1> + + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };", + // CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr, ptr)>, i32) -> !llvm.struct<()> + + tt.store %ptrs, %vs, %mask, {} : tensor<128xf32> + + return +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 51a0e4a11..4628e326d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,11 +1,10 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s - module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) // Here the 128 comes from the 4 in module attribute multiples 32 -// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}} +// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}} func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return