diff --git a/.ci/azure-pipelines.yml b/.ci/azure-pipelines.yml index fd56ee836..77cb8a607 100644 --- a/.ci/azure-pipelines.yml +++ b/.ci/azure-pipelines.yml @@ -1,4 +1,8 @@ name: Triton CI + +workspace: + clean: all + pool: name: default @@ -16,27 +20,18 @@ pr: # Pipeline steps: - script: | - mkdir $(venv) - python -m virtualenv --python=python3 $(venv) - source $(venv)/bin/activate - python -m pip install --upgrade pip - pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 \ - -f https://download.pytorch.org/whl/torch_stable.html + alias python='python3' cd python - python setup.py install + pip3 install -e . displayName: Setup python environment - script: | - source $(venv)/bin/activate - pip install matplotlib pandas cd python/bench - python -m run + python3 -m run - publish: python/bench/results artifact: Benchmarks - script: | - source $(venv)/bin/activate - pip install pytest pytest . displayName: 'Run Python tests' \ No newline at end of file diff --git a/.ci/build-wheels.yml b/.ci/build-wheels.yml index 7c56ce186..2ec71820e 100644 --- a/.ci/build-wheels.yml +++ b/.ci/build-wheels.yml @@ -4,6 +4,9 @@ pr: none jobs: - job: linux + workspace: + clean: all + timeoutInMinutes: 180 pool: default @@ -33,4 +36,5 @@ jobs: inputs: {pathtoPublish: 'wheelhouse'} - bash: | python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD) - displayName: Upload wheels to PyPI \ No newline at end of file + displayName: Upload wheels to PyPI + \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 23887c91e..d82a282a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,48 +22,24 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++17") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") +# if(APPLE) +# set(CMAKE_OSX_SYSROOT "/") +# set(CMAKE_OSX_DEPLOYMENT_TARGET "") +# endif() + ########## # LLVM ########## -get_cmake_property(_variableNames VARIABLES) -set(__variableNames ${_variableNames}) - -configure_file(cmake/DownloadLLVM.in ${TRITON_LLVM_BUILD_DIR}/llvm-download/CMakeLists.txt) -execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . - WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download" -) -execute_process(COMMAND "${CMAKE_COMMAND}" --build . - WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download" -) -set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "") -set(LLVM_BUILD_RUNTIME "OFF" CACHE INTERNAL "") -set(LLVM_BUILD_RUNTIMES "OFF" CACHE INTERNAL "") -set(LLVM_BUILD_TOOLS "OFF" CACHE INTERNAL "") -set(LLVM_BUILD_UTILS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_BENCHMARKS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_DOCS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_EXAMPLES "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_GO_TESTS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_RUNTIME "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_TESTS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_TOOLS "OFF" CACHE INTERNAL "") -set(LLVM_INCLUDE_UTILS "OFF" CACHE INTERNAL "") -add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src - ${TRITON_LLVM_BUILD_DIR}/llvm-build) -get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS) -# remove LLVM-specific variables so we don't pollute GUI -get_cmake_property(_variableNames VARIABLES) -list(REMOVE_ITEM _variableNames ${__variableNames}) -list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES}) -foreach (_variableName ${_variableNames}) - unset(${_variableName} CACHE) -endforeach() -include_directories("${TRITON_LLVM_BUILD_DIR}/llvm-build/include/" - "${TRITON_LLVM_BUILD_DIR}/llvm-src/include/") +find_package(LLVM 11 REQUIRED COMPONENTS "nvptx") +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +include_directories("${LLVM_INCLUDE_DIRS}") +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14") +endif() # Python module if(BUILD_PYTHON_MODULE) @@ -87,8 +63,15 @@ endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -target_link_libraries(triton ${LLVM_LIBRARIES}) +target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) +target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS}) +message(STATUS ${LLVM_LDFLAGS}) if(BUILD_PYTHON_MODULE) - target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES}) + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + # Check if the platform is MacOS + if(APPLE) + set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") + endif() + target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS}) endif() diff --git a/cmake/DownloadLLVM.in b/cmake/DownloadLLVM.in deleted file mode 100644 index afe3d8362..000000000 --- a/cmake/DownloadLLVM.in +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.6) - -project(llvm-download NONE) -include(ExternalProject) - - -ExternalProject_Add(llvm - URL "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz" - SOURCE_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-src" - BINARY_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-build" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/cmake/FindLLVM.cmake b/cmake/FindLLVM.cmake index 3de161c64..2bdf22c28 100644 --- a/cmake/FindLLVM.cmake +++ b/cmake/FindLLVM.cmake @@ -1,3 +1,4 @@ + # - Find LLVM headers and libraries. # This module locates LLVM and adapts the llvm-config output for use with # CMake. @@ -7,14 +8,18 @@ # The following variables are defined: # LLVM_FOUND - true if LLVM was found # LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers. -# LLVM_HOST_TARGET - Target triple used to configure LLVM. +# LLVM_ENABLE_ASSERTIONS - Whether LLVM was built with enabled assertions (ON/OFF). # LLVM_INCLUDE_DIRS - Directory containing LLVM include files. +# LLVM_IS_SHARED - Whether LLVM is going to be linked dynamically (ON) or statically (OFF). # LLVM_LDFLAGS - Linker flags to add when linking against LLVM # (includes -LLLVM_LIBRARY_DIRS). # LLVM_LIBRARIES - Full paths to the library files to link against. # LLVM_LIBRARY_DIRS - Directory containing LLVM libraries. +# LLVM_NATIVE_ARCH - Backend corresponding to LLVM_HOST_TARGET, e.g., +# X86 for x86_64 and i686 hosts. # LLVM_ROOT_DIR - The root directory of the LLVM installation. # llvm-config is searched for in ${LLVM_ROOT_DIR}/bin. +# LLVM_TARGETS_TO_BUILD - List of built LLVM targets. # LLVM_VERSION_MAJOR - Major version of LLVM. # LLVM_VERSION_MINOR - Minor version of LLVM. # LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). @@ -28,16 +33,31 @@ # We also want an user-specified LLVM_ROOT_DIR to take precedence over the # system default locations such as /usr/local/bin. Executing find_program() # multiples times is the approach recommended in the docs. -set(llvm_config_names llvm-config-11 llvm-config-11.0 - llvm-config-10 llvm-config-10.0 llvm-config100 - llvm-config-9 llvm-config-9.0 llvm-config90 - llvm-config-8 llvm-config-8.0 llvm-config80 +set(llvm_config_names llvm-config-12.0 llvm-config120 llvm-config-12 + llvm-config-11.0 llvm-config110 llvm-config-11 + llvm-config-10.0 llvm-config100 llvm-config-10 + llvm-config-9.0 llvm-config90 llvm-config-9 + llvm-config-8.0 llvm-config80 llvm-config-8 + llvm-config-7.0 llvm-config70 llvm-config-7 + llvm-config-6.0 llvm-config60 llvm-config) find_program(LLVM_CONFIG NAMES ${llvm_config_names} PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH DOC "Path to llvm-config tool.") find_program(LLVM_CONFIG NAMES ${llvm_config_names}) +if(APPLE) + # extra fallbacks for MacPorts & Homebrew + find_program(LLVM_CONFIG + NAMES ${llvm_config_names} + PATHS /opt/local/libexec/llvm-11/bin /opt/local/libexec/llvm-10/bin /opt/local/libexec/llvm-9.0/bin + /opt/local/libexec/llvm-8.0/bin /opt/local/libexec/llvm-7.0/bin /opt/local/libexec/llvm-6.0/bin + /opt/local/libexec/llvm/bin + /usr/local/opt/llvm@11/bin /usr/local/opt/llvm@10/bin /usr/local/opt/llvm@9/bin + /usr/local/opt/llvm@8/bin /usr/local/opt/llvm@7/bin /usr/local/opt/llvm@6/bin + /usr/local/opt/llvm/bin + NO_DEFAULT_PATH) +endif() # Prints a warning/failure message depending on the required/quiet flags. Copied # from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed. @@ -46,7 +66,7 @@ macro(_LLVM_FAIL _msg) message(FATAL_ERROR "${_msg}") else() if(NOT LLVM_FIND_QUIETLY) - message(STATUS "${_msg}") + message(WARNING "${_msg}") endif() endif() endmacro() @@ -54,7 +74,7 @@ endmacro() if(NOT LLVM_CONFIG) if(NOT LLVM_FIND_QUIETLY) - message(WARNING "Could not find llvm-config (LLVM >= ${LLVM_FIND_VERSION}). Try manually setting LLVM_CONFIG to the llvm-config executable of the installation to use.") + _LLVM_FAIL("No LLVM installation (>= ${LLVM_FIND_VERSION}) found. Try manually setting the 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' variables.") endif() else() macro(llvm_set var flag) @@ -63,7 +83,7 @@ else() endif() set(result_code) execute_process( - COMMAND ${LLVM_CONFIG} --${flag} + COMMAND ${LLVM_CONFIG} --link-static --${flag} RESULT_VARIABLE result_code OUTPUT_VARIABLE LLVM_${var} OUTPUT_STRIP_TRAILING_WHITESPACE @@ -77,13 +97,13 @@ else() endif() endif() endmacro() - macro(llvm_set_libs var flag) + macro(llvm_set_libs var flag components) if(LLVM_FIND_QUIETLY) set(_quiet_arg ERROR_QUIET) endif() set(result_code) execute_process( - COMMAND ${LLVM_CONFIG} --${flag} ${LLVM_FIND_COMPONENTS} + COMMAND ${LLVM_CONFIG} --link-static --${flag} ${components} RESULT_VARIABLE result_code OUTPUT_VARIABLE tmplibs OUTPUT_STRIP_TRAILING_WHITESPACE @@ -91,7 +111,7 @@ else() ) if(result_code) _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'") - else() + else() file(TO_CMAKE_PATH "${tmplibs}" tmplibs) string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs}) endif() @@ -99,31 +119,29 @@ else() llvm_set(VERSION_STRING version) llvm_set(CXXFLAGS cxxflags) - llvm_set(HOST_TARGET host-target) llvm_set(INCLUDE_DIRS includedir true) llvm_set(ROOT_DIR prefix true) llvm_set(ENABLE_ASSERTIONS assertion-mode) - # The LLVM version string _may_ contain a git/svn suffix, so cut that off - string(SUBSTRING "${LLVM_VERSION_STRING}" 0 5 LLVM_VERSION_BASE_STRING) + # The LLVM version string _may_ contain a git/svn suffix, so match only the x.y.z part + string(REGEX MATCH "^[0-9]+[.][0-9]+[.][0-9]+" LLVM_VERSION_BASE_STRING "${LLVM_VERSION_STRING}") - # Versions below 4.0 do not support components debuginfomsf and demangle - if(${LLVM_VERSION_STRING} MATCHES "^3\\..*") - list(REMOVE_ITEM LLVM_FIND_COMPONENTS "debuginfomsf" index) - list(REMOVE_ITEM LLVM_FIND_COMPONENTS "demangle" index) - endif() - # Versions below 8.0 not supported - if(${LLVM_VERSION_STRING} MATCHES "^[3-7]\\..*") - message(FATAL_ERROR "LLVM version below 8.0 not supported") + llvm_set(SHARED_MODE shared-mode) + if(LLVM_SHARED_MODE STREQUAL "shared") + set(LLVM_IS_SHARED ON) + else() + set(LLVM_IS_SHARED OFF) endif() llvm_set(LDFLAGS ldflags) - # In LLVM 3.5+, the system library dependencies (e.g. "-lz") are accessed - # using the separate "--system-libs" flag. llvm_set(SYSTEM_LIBS system-libs) string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}") + if(APPLE) # unclear why/how this happens + string(REPLACE "-llibxml2.tbd" "-lxml2" LLVM_LDFLAGS ${LLVM_LDFLAGS}) + endif() + llvm_set(LIBRARY_DIRS libdir true) - llvm_set_libs(LIBRARIES libs) + llvm_set_libs(LIBRARIES libfiles "${LLVM_FIND_COMPONENTS}") # LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0 # but code for it is not in shared library if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen") @@ -132,37 +150,50 @@ else() endif() endif() - # Versions below 4.0 do not support llvm-config --cmakedir - if(${LLVM_VERSION_STRING} MATCHES "^3\\..*") - set(LLVM_CMAKEDIR ${LLVM_LIBRARY_DIRS}/cmake/llvm) - else() - llvm_set(CMAKEDIR cmakedir) - endif() - + llvm_set(CMAKEDIR cmakedir) llvm_set(TARGETS_TO_BUILD targets-built) string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD}) + + # Parse LLVM_NATIVE_ARCH manually from LLVMConfig.cmake; including it leads to issues like + # https://github.com/ldc-developers/ldc/issues/3079. + file(STRINGS "${LLVM_CMAKEDIR}/LLVMConfig.cmake" LLVM_NATIVE_ARCH LIMIT_COUNT 1 REGEX "^set\\(LLVM_NATIVE_ARCH (.+)\\)$") + string(REGEX MATCH "set\\(LLVM_NATIVE_ARCH (.+)\\)" LLVM_NATIVE_ARCH "${LLVM_NATIVE_ARCH}") + set(LLVM_NATIVE_ARCH ${CMAKE_MATCH_1}) + message(STATUS "LLVM_NATIVE_ARCH: ${LLVM_NATIVE_ARCH}") + + # On CMake builds of LLVM, the output of llvm-config --cxxflags does not + # include -fno-rtti, leading to linker errors. Be sure to add it. + if(NOT MSVC AND (CMAKE_COMPILER_IS_GNUCXX OR (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang"))) + if(NOT ${LLVM_CXXFLAGS} MATCHES "-fno-rtti") + set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fno-rtti") + endif() + endif() + + # Remove some clang-specific flags for gcc. + if(CMAKE_COMPILER_IS_GNUCXX) + string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + # this requires more recent gcc versions (not supported by 4.9) + string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + endif() + + # Remove gcc-specific flags for clang. + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + endif() + + string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" ) + string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" ) + + if (${LLVM_VERSION_STRING} VERSION_LESS ${LLVM_FIND_VERSION}) + _LLVM_FAIL("Unsupported LLVM version ${LLVM_VERSION_STRING} found (${LLVM_CONFIG}). At least version ${LLVM_FIND_VERSION} is required. You can also set variables 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' to use a different LLVM installation.") + endif() endif() -# Remove some clang-specific flags for gcc. -if(CMAKE_COMPILER_IS_GNUCXX) - string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) - string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) - string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) - string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) -endif() - -# Remove gcc-specific flags for clang. -if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") - string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) -endif() - -string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" ) -string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" ) - - # Use the default CMake facilities for handling QUIET/REQUIRED. include(FindPackageHandleStandardArgs) find_package_handle_standard_args(LLVM - REQUIRED_VARS LLVM_ROOT_DIR LLVM_HOST_TARGET - VERSION_VAR LLVM_VERSION_STRING) + REQUIRED_VARS LLVM_ROOT_DIR + VERSION_VAR LLVM_VERSION_STRING) \ No newline at end of file diff --git a/lib/driver/module.cc b/lib/driver/module.cc index c31e3cca4..3f5b1d953 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -59,6 +59,13 @@ std::string exec(const char* cmd) { return result; } + void LLVMInitializeNVPTXTargetInfo(); + void LLVMInitializeNVPTXTarget(); + void LLVMInitializeNVPTXTargetMC(); + void LLVMInitializeNVPTXAsmPrinter(); + void LLVMInitializeNVPTXAsmParser(); + + namespace triton { namespace driver @@ -68,14 +75,14 @@ namespace driver // Base // /* ------------------------ */ + void module::init_llvm() { static bool init = false; if(!init){ - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); init = true; } } @@ -111,80 +118,81 @@ void module::compile_llvm_module(std::unique_ptr module, const std /* ------------------------ */ host_module::host_module(std::unique_ptr src): module(host_module_t(), true) { - init_llvm(); - // create kernel wrapper - llvm::LLVMContext &ctx = src->getContext(); - llvm::Type *void_ty = llvm::Type::getVoidTy(ctx); - llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo(); - llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx); - std::vector tys = {args_ty, int32_ty, int32_ty, int32_ty}; - llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false); - llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src); - llvm::Function* fn = &*src->getFunctionList().begin(); - llvm::FunctionType *fn_ty = fn->getFunctionType(); - std::vector fn_args(fn_ty->getNumParams()); - std::vector ptrs(fn_args.size() - 3); - llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main); - llvm::IRBuilder<> ir_builder(ctx); - ir_builder.SetInsertPoint(entry); - auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; }; - llvm::Value* base = main->arg_begin(); - llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType()); + throw std::runtime_error("CPU unsupported"); +// init_llvm(); +// // create kernel wrapper +// llvm::LLVMContext &ctx = src->getContext(); +// llvm::Type *void_ty = llvm::Type::getVoidTy(ctx); +// llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo(); +// llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx); +// std::vector tys = {args_ty, int32_ty, int32_ty, int32_ty}; +// llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false); +// llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src); +// llvm::Function* fn = &*src->getFunctionList().begin(); +// llvm::FunctionType *fn_ty = fn->getFunctionType(); +// std::vector fn_args(fn_ty->getNumParams()); +// std::vector ptrs(fn_args.size() - 3); +// llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main); +// llvm::IRBuilder<> ir_builder(ctx); +// ir_builder.SetInsertPoint(entry); +// auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; }; +// llvm::Value* base = main->arg_begin(); +// llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType()); - size_t offset = 0; - for(unsigned i = 0; i < ptrs.size(); i++){ - ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset)); - size_t nbytes = get_size(fn_ty->getParamType(i)); - offset += nbytes; - if(i < ptrs.size() - 1){ - size_t np1bytes = get_size(fn_ty->getParamType(i+1)); - offset = (offset + np1bytes - 1) / np1bytes * np1bytes; - } - } - for(unsigned i = 0; i < ptrs.size(); i++) - ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo()); - for(unsigned i = 0; i < ptrs.size(); i++) - fn_args[i] = ir_builder.CreateLoad(ptrs[i]); +// size_t offset = 0; +// for(unsigned i = 0; i < ptrs.size(); i++){ +// ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset)); +// size_t nbytes = get_size(fn_ty->getParamType(i)); +// offset += nbytes; +// if(i < ptrs.size() - 1){ +// size_t np1bytes = get_size(fn_ty->getParamType(i+1)); +// offset = (offset + np1bytes - 1) / np1bytes * np1bytes; +// } +// } +// for(unsigned i = 0; i < ptrs.size(); i++) +// ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo()); +// for(unsigned i = 0; i < ptrs.size(); i++) +// fn_args[i] = ir_builder.CreateLoad(ptrs[i]); - fn_args[fn_args.size() - 3] = main->arg_begin() + 1; - fn_args[fn_args.size() - 2] = main->arg_begin() + 2; - fn_args[fn_args.size() - 1] = main->arg_begin() + 3; - ir_builder.CreateCall(fn, fn_args); - ir_builder.CreateRetVoid(); +// fn_args[fn_args.size() - 3] = main->arg_begin() + 1; +// fn_args[fn_args.size() - 2] = main->arg_begin() + 2; +// fn_args[fn_args.size() - 1] = main->arg_begin() + 3; +// ir_builder.CreateCall(fn, fn_args); +// ir_builder.CreateRetVoid(); -// llvm::legacy::PassManager pm; -// pm.add(llvm::createPrintModulePass(llvm::outs())); -// pm.add(llvm::createVerifierPass()); -// pm.run(*src); +//// llvm::legacy::PassManager pm; +//// pm.add(llvm::createPrintModulePass(llvm::outs())); +//// pm.add(llvm::createVerifierPass()); +//// pm.run(*src); -// create execution engine - for(llvm::Function& fn: src->functions()) - hst_->functions[fn.getName().str()] = &fn; +//// create execution engine +// for(llvm::Function& fn: src->functions()) +// hst_->functions[fn.getName().str()] = &fn; -// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); -// auto DL = JTMB.getDefaultDataLayoutForTarget(); -// auto CIRC = std::unique_ptr(new llvm::orc::ConcurrentIRCompiler(JTMB)); -// hst_->ES = new llvm::orc::ExecutionSession(); -// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr(new llvm::SectionMemoryManager()); }); -// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC); -// hst_->DL = new llvm::DataLayout(std::move(*DL)); -// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL); -// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr(new llvm::LLVMContext())); -// hst_->MainJD = &hst_->ES->createJITDylib("
"); -// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( -// hst_->DL->getGlobalPrefix()))); -// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx))); -// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress()); +//// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); +//// auto DL = JTMB.getDefaultDataLayoutForTarget(); +//// auto CIRC = std::unique_ptr(new llvm::orc::ConcurrentIRCompiler(JTMB)); +//// hst_->ES = new llvm::orc::ExecutionSession(); +//// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr(new llvm::SectionMemoryManager()); }); +//// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC); +//// hst_->DL = new llvm::DataLayout(std::move(*DL)); +//// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL); +//// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr(new llvm::LLVMContext())); +//// hst_->MainJD = &hst_->ES->createJITDylib("
"); +//// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( +//// hst_->DL->getGlobalPrefix()))); +//// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx))); +//// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress()); - llvm::EngineBuilder builder(std::move(src)); - builder.setErrorStr(&hst_->error); - builder.setMCJITMemoryManager(std::make_unique()); - builder.setOptLevel(llvm::CodeGenOpt::Aggressive); - builder.setEngineKind(llvm::EngineKind::JIT); - hst_->engine = builder.create(); - hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main")); +// llvm::EngineBuilder builder(std::move(src)); +// builder.setErrorStr(&hst_->error); +// builder.setMCJITMemoryManager(std::make_unique()); +// builder.setOptLevel(llvm::CodeGenOpt::Aggressive); +// builder.setEngineKind(llvm::EngineKind::JIT); +// hst_->engine = builder.create(); +// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main")); } std::unique_ptr host_module::symbol(const char *name) const { @@ -211,7 +219,7 @@ static std::map vptx = { {11010, 71}, {11020, 72}, {11030, 73}, - {11040, 74} + {11040, 73} }; std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { diff --git a/python/setup.py b/python/setup.py index bf7675f32..3e07122af 100644 --- a/python/setup.py +++ b/python/setup.py @@ -66,7 +66,7 @@ class CMakeBuild(build_ext): "-DBUILD_TUTORIALS=OFF", "-DBUILD_PYTHON_MODULE=ON", #'-DPYTHON_EXECUTABLE=' + sys.executable, - #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, + '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) ]