#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// NOTE(review): the header name of the last #include was lost when this file
// was extracted. <memory> is what this file needs for std::unique_ptr and
// std::make_unique; confirm the original header against the upstream file.
#include <memory>

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace LLVM {

/// Returns the attribute name "llvm.struct_attrs".
/// NOTE(review): not referenced anywhere else in the visible part of this
/// file; confirm whether it is still needed.
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }

} // namespace LLVM
} // namespace mlir

namespace {

// The following code is borrowed from the MLIR project, including these
// functions and classes:
// - filterFuncAttributes
// - ConvertOpToLLVMPattern
// - FuncOpConversion
//
// The code is hidden in .cpp files in the MLIR repo, so we can't call it
// directly. Such code snippets have been refactored and added to LLVMCommon
// in the latest MLIR code, but the v14.0.0 version currently used in Triton
// doesn't contain them.
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
//
// The original code:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
// All the rights are reserved by LLVM community.

/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(ArrayRef attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.getName() == SymbolTable::getSymbolAttrName() || attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == "std.varargs" || (filterArgAttrs && attr.getName() == FunctionOpInterface::getArgDictAttrName())) continue; result.push_back(attr); } } struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. LLVM::LLVMFuncOp convertFuncOpToLLVMFuncOp(FuncOp funcOp, ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp->getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); assert(llvmType); if (!llvmType) return nullptr; // Propagate argument attributes to all converted arguments obtained after // converting a given original argument. 
SmallVector attributes; filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, attributes); if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( llvmType.cast().getNumParams()); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto mapping = result.getInputMapping(i); assert(mapping.hasValue() && "unexpected deletion of function argument"); for (size_t j = 0; j < mapping->size; ++j) newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; } attributes.push_back( rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { if (pair.value().getName() == "llvm.linkage") { attributes.erase(attributes.begin() + pair.index()); break; } } // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. LLVM::Linkage linkage = LLVM::Linkage::External; if (funcOp->hasAttr("llvm.linkage")) { auto attr = funcOp->getAttr("llvm.linkage").dyn_cast(); if (!attr) { funcOp->emitError() << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr"; return nullptr; } linkage = attr.getLinkage(); } auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal*/ false, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, &result))) return nullptr; return newFuncOp; } }; /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. 
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, int numWarps) : FuncOpConversionBase(converter), NumWarps(numWarps) { mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion( [&](triton::PointerType type) { return convertTritonPointerType(type); }); } Type convertTritonPointerType(triton::PointerType type) { return LLVM::LLVMPointerType::get(type.getPointeeType(), type.getAddressSpace()); } LogicalResult matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); auto ctx = funcOp->getContext(); auto i32 = IntegerType::get(ctx, 32); // Set an attribute for maxntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. newFuncOp->setAttr(NVVMMetadataField::MaxNTid, rewriter.getIntegerAttr(i32, 32 * NumWarps)); rewriter.eraseOp(funcOp); return success(); } private: int NumWarps{0}; }; struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); unsigned numArguments = op.getNumOperands(); // Currently, Triton kernel function always return nothing. // TODO(Superjomn) add support for non-inline device function if (numArguments > 0) { return rewriter.notifyMatchFailure( op, "Only kernel function with nothing returned is supported."); } rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), op->getAttrs()); return success(); } }; // Extract numWarps information from TritonGPU module, return 0 if failed. 
// This is a naive implementation, it assumes that all the blocked layout should // have the same numWarps setting in a module, it just find a blocked layout // encoding and return the warpsPerCTA field. int extractNumWarps(mlir::ModuleOp module) { int numWarps{}; if (module->hasAttr(AttrNumWarpsName)) numWarps = module->getAttr(AttrNumWarpsName) .dyn_cast() .getValue() .getZExtValue(); else llvm::report_fatal_error( "TritonGPU module should contain a triton_gpu.num-warps attribute"); return numWarps; } } // namespace void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps) { patterns.add<::FuncOpConversion>(typeConverter, numWarps); patterns.add<::ReturnOpConversion>(typeConverter); } class ConvertTritonGPUToLLVM : public ConvertTritonGPUToLLVMBase { public: ConvertTritonGPUToLLVM() = default; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); LLVMTypeConverter typeConverter(context); TritonLLVMConversionTarget target(*context, typeConverter); RewritePatternSet patterns(context); // Add arith's patterns to help convert scalar expression to LLVM. mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); int numWarps = extractNumWarps(mod); populateTritonToLLVMPatterns(typeConverter, patterns, numWarps); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); } }; namespace mlir { TritonLLVMConversionTarget::TritonLLVMConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) : ConversionTarget(ctx), typeConverter(typeConverter) { addLegalDialect(); } namespace triton { std::unique_ptr> createConvertTritonGPUToLLVMPass() { return std::make_unique<::ConvertTritonGPUToLLVM>(); } } // namespace triton } // namespace mlir