diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 92635eca3..6d5673a02 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -5,10 +5,10 @@ add_subdirectory(FileCheck)
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
 
-add_llvm_executable(triton-opt triton-opt.cpp)
+add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
 
 # TODO: what's this?
-# llvm_update_compile_flags(triton-opt)
+llvm_update_compile_flags(triton-opt)
 target_link_libraries(triton-opt PRIVATE
   TritonAnalysis
   TritonTransforms
@@ -23,4 +23,36 @@ target_link_libraries(triton-opt PRIVATE
   MLIRTransforms
 )
 
-mlir_check_all_link_libraries(triton-opt)
\ No newline at end of file
+mlir_check_all_link_libraries(triton-opt)
+
+
+add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
+llvm_update_compile_flags(triton-translate)
+target_link_libraries(triton-translate PRIVATE
+        TritonAnalysis
+        TritonTransforms
+        TritonGPUTransforms
+        TritonLLVMIR
+        TritonDriver
+        ${dialect_libs}
+        ${conversion_libs}
+        # tests
+        TritonTestAnalysis
+
+        LLVMCore
+        LLVMSupport
+        LLVMOption
+        LLVMCodeGen
+        LLVMAsmParser
+
+        # MLIR core
+        MLIROptLib
+        MLIRIR
+        MLIRPass
+        MLIRSupport
+        MLIRTransforms
+        MLIRExecutionEngine
+        MLIRTransformUtils
+        MLIRLLVMToLLVMIRTranslation
+        )
+mlir_check_all_link_libraries(triton-translate)
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
new file mode 100644
index 000000000..ad5ec5f65
--- /dev/null
+++ b/bin/triton-translate.cpp
@@ -0,0 +1,157 @@
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
+#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "triton/driver/llvm.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <memory>
+#include <string>
+
+namespace mlir {
+namespace triton {
+
+// Reads the MLIR file `inputFilename` into `context` and lowers it with the
+// TritonGPU-to-LLVM conversion pass. On any failure a diagnostic is printed
+// to stderr and a null module is returned.
+OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
+                                     MLIRContext &context) {
+  std::string errorMessage;
+  auto input = openInputFile(inputFilename, &errorMessage);
+  if (!input) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+
+  // NOTE(review): only the dialects whose headers are included above are
+  // registered explicitly; add any other dialect the input may contain.
+  mlir::DialectRegistry registry;
+  registry.insert<TritonDialect, triton::gpu::TritonGPUDialect>();
+  context.appendDialectRegistry(registry);
+
+  auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
+      -> OwningOpRef<ModuleOp> {
+    llvm::SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+
+    context.loadAllAvailableDialects();
+    context.allowUnregisteredDialects();
+
+    OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
+    if (!module) {
+      llvm::errs() << "Parse MLIR file failed.\n";
+      return nullptr;
+    }
+
+    return module;
+  };
+
+  auto module = processBuffer(std::move(input));
+  if (!module) {
+    return nullptr;
+  }
+
+  mlir::PassManager pm(module->getContext());
+  applyPassManagerCLOptions(pm);
+
+  pm.addPass(createConvertTritonGPUToLLVMPass());
+
+  if (failed(pm.run(module->getOperation()))) {
+    llvm::errs() << "Pass execution failed\n";
+    return nullptr;
+  }
+
+  return module;
+}
+
+// Driver for the `triton-translate` tool: parses the command line, lowers the
+// input module to LLVM IR and writes either the LLVM IR text or the generated
+// PTX to the file given by `-o`. Returns failure() on any error.
+LogicalResult tritonTranslateMain(int argc, char **argv,
+                                  llvm::StringRef toolName) {
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::init("-"));
+
+  static llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+
+  static llvm::cl::opt<std::string> targetKind(
+      "target", llvm::cl::desc("<translation target: llvmir or ptx>"),
+      llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
+
+  static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
+                                   llvm::cl::init(80));
+
+  static llvm::cl::opt<int> ptxVersion(
+      "ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
+
+  llvm::InitLLVM y(argc, argv);
+
+  registerAsmPrinterCLOptions();
+  registerMLIRContextCLOptions();
+  llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
+
+  // Reject unknown targets up front instead of silently emitting nothing.
+  const bool emitLLVMIR = targetKind == "llvmir";
+  if (!emitLLVMIR && targetKind != "ptx") {
+    llvm::errs() << "Unsupported target: " << targetKind.getValue() << "\n";
+    return failure();
+  }
+
+  mlir::MLIRContext context;
+  auto module = loadMLIRModule(inputFilename, context);
+  if (!module) {
+    return failure();
+  }
+
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  llvm::LLVMContext llvmContext;
+  auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
+  if (!llvmir) {
+    // Bail out: `llvmir` is null and must not be dereferenced below.
+    llvm::errs() << "Translate to LLVM IR failed\n";
+    return failure();
+  }
+
+  if (emitLLVMIR)
+    output->os() << *llvmir << '\n';
+  else
+    output->os() << ::triton::driver::llir_to_ptx(
+        llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
+
+  // Keep the output file; ToolOutputFile deletes it on destruction otherwise.
+  output->keep();
+  return success();
+}
+
+} // namespace triton
+} // namespace mlir
+
+int main(int argc, char **argv) {
+  return failed(mlir::triton::tritonTranslateMain(
+      argc, argv, "Triton Translate Testing Tool."));
+}
diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
index adbd2ef52..85ffc1944 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
@@ -20,6 +20,15 @@ public:
 
 namespace triton {
 
+// Names for identifying different NVVM annotations. They are used as attribute
+// names in MLIR modules. Refer to
+// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
+// the full list.
+struct NVVMMetadataField {
+  static constexpr char MaxNTid[] = "nvvm.maxntid";
+  static constexpr char Kernel[] = "nvvm.kernel";
+};
+
 std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
 
 } // namespace triton
diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
new file mode 100644
index 000000000..0ec11d524
--- /dev/null
+++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
@@ -0,0 +1,24 @@
+#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
+#define TRITON_TARGET_LLVMIRTRANSLATION_H
+#include <memory>
+
+namespace llvm {
+class Module;
+class LLVMContext;
+} // namespace llvm
+
+namespace mlir {
+class ModuleOp;
+} // namespace mlir
+
+namespace mlir {
+namespace triton {
+
+// Translate mlir LLVM dialect to LLVMIR, return null if failed.
+std::unique_ptr<llvm::Module>
+TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_TARGET_LLVMIRTRANSLATION_H
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index d19e99c0b..5a6ba8951 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(driver)
 add_subdirectory(Analysis)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
+add_subdirectory(Target)
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 7ccdf7bc0..78e51b962 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -155,7 +155,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
     auto i32 = IntegerType::get(ctx, 32);
     // Set an attribute for maxntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
-    newFuncOp->setAttr("nvvm.maxntidx",
+    newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
                        rewriter.getIntegerAttr(i32, 32 * NumWarps));
 
     rewriter.eraseOp(funcOp);
diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt
new file mode 100644
index 000000000..88b49be6b
--- /dev/null
+++ b/lib/Target/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(LLVMIR)
diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt
new file mode 100644
index 000000000..73a89676a
--- /dev/null
+++ b/lib/Target/LLVMIR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_translation_library(TritonLLVMIR
+        LLVMIRTranslation.cpp
+
+        LINK_COMPONENTS
+        Core
+
+        LINK_LIBS PUBLIC
+        MLIRExecutionEngine
+        MLIRIR
+        MLIRLLVMIR
+        MLIRSupport
+        MLIRTargetLLVMIRExport
+        TritonDriver
+        )
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
new file mode 100644
index 000000000..ed4931a8d
--- /dev/null
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -0,0 +1,120 @@
+#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
+#include "triton/driver/llvm.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/Error.h"
+
+namespace mlir {
+namespace triton {
+
+// Describes NVVM Metadata. It is used to record the nvvm related meta
+// information from mlir module.
+struct NVVMMetadata {
+  int maxntidx{-1};
+  bool is_kernel{};
+  // Free to extend with other information.
+};
+
+// Add the nvvm related metadata to LLVM IR: one `!nvvm.annotations` node is
+// appended for `maxntidx` (when positive) and one for `kernel` (when set).
+void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
+  auto *module = func->getParent();
+  auto &ctx = func->getContext();
+  auto *i32Ty = llvm::Type::getInt32Ty(ctx);
+
+  // Appends `!{func, !"name", i32 value}` to `!nvvm.annotations`.
+  auto annotate = [&](llvm::StringRef name, int value) {
+    llvm::Metadata *mdArgs[] = {
+        llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, name),
+        llvm::ValueAsMetadata::get(llvm::ConstantInt::get(i32Ty, value))};
+    module->getOrInsertNamedMetadata("nvvm.annotations")
+        ->addOperand(llvm::MDNode::get(ctx, mdArgs));
+  };
+
+  if (metadata.maxntidx > 0)
+    annotate("maxntidx", metadata.maxntidx);
+  if (metadata.is_kernel)
+    annotate("kernel", 1);
+}
+
+// Records the NVVM attributes of every LLVM function in `module` into `dic`,
+// keyed by function name; functions carrying none are left out.
+void extractNVVMMetadata(mlir::ModuleOp module,
+                         llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
+  for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
+    NVVMMetadata meta;
+    bool hasMetadata{};
+
+    // maxntid: getAttrOfType yields null for a missing or non-integer
+    // attribute, so a malformed attribute is ignored instead of dereferenced.
+    if (auto maxNTid =
+            op->getAttrOfType<IntegerAttr>(NVVMMetadataField::MaxNTid)) {
+      meta.maxntidx = static_cast<int>(maxNTid.getInt());
+      hasMetadata = true;
+    }
+
+    // kernel
+    if (op->hasAttr(NVVMMetadataField::Kernel)) {
+      meta.is_kernel = true;
+      hasMetadata = true;
+    }
+
+    if (hasMetadata)
+      dic->try_emplace(op.getNameAttr().strref(), meta);
+  }
+}
+
+// Translate mlir LLVM dialect to LLVMIR, return null if failed.
+std::unique_ptr<llvm::Module>
+TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
+  auto context = module->getContext();
+  DialectRegistry registry;
+  registerLLVMDialectTranslation(registry);
+  context->appendDialectRegistry(registry);
+
+  llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
+  extractNVVMMetadata(module, &nvvmMetadata);
+
+  auto llvmModule = mlir::translateModuleToLLVMIR(module, *llvmContext);
+  if (!llvmModule) {
+    llvm::errs() << "Failed to emit LLVM IR\n";
+    return nullptr;
+  }
+
+  // Initialize LLVM targets.
+  ::triton::driver::init_llvm();
+  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
+
+  auto optPipeline = mlir::makeOptimizingTransformer(
+      /*optLevel=*/3, /*sizeLevel=*/0,
+      /*targetMachine=*/nullptr);
+
+  if (auto err = optPipeline(llvmModule.get())) {
+    // llvm::toString consumes the error; an llvm::Error must be handled
+    // before it is destroyed.
+    llvm::errs() << "Failed to optimize LLVM IR "
+                 << llvm::toString(std::move(err)) << "\n";
+    return nullptr;
+  }
+
+  // Annotate after optimization, matching functions by name.
+  for (auto &func : llvmModule->functions()) {
+    auto it = nvvmMetadata.find(func.getName());
+    if (it != nvvmMetadata.end())
+      amendLLVMFunc(&func, it->second);
+  }
+
+  return llvmModule;
+}
+
+} // namespace triton
+} // namespace mlir
diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir
new file mode 100644
index 000000000..70291ec2e
--- /dev/null
+++ b/test/Target/tritongpu_to_llvmir.mlir
@@ -0,0 +1,13 @@
+// RUN: triton-translate %s --target=llvmir | FileCheck %s
+
+// == LLVM IR check begin ==
+// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
+// CHECK: define void @test_empty_kernel
+// CHECK: !nvvm.annotations
+// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
+
+func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
+
+  return
+}
+
diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir
new file mode 100644
index 000000000..b09600065
--- /dev/null
+++ b/test/Target/tritongpu_to_ptx.mlir
@@ -0,0 +1,11 @@
+// RUN: triton-translate %s --target=ptx --sm=80 --ptx-version=10000 | FileCheck %s
+
+// CHECK-LABEL: // Generated by LLVM NVPTX Back-End
+// CHECK: .version 6.3
+// CHECK: .target sm_80
+// CHECK: .address_size 64
+
+func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
+
+  return
+}