[CI] Various improvements to CI (#137)

Add clean-up before CI runs. Use the system's static LLVM-11 libraries instead of recompiling LLVM. There are still no run-time LLVM dependencies.
This commit is contained in:
Philippe Tillet
2021-07-22 11:41:51 -07:00
committed by Philippe Tillet
parent 298aead378
commit 8eb63bcb01
7 changed files with 196 additions and 190 deletions

View File

@@ -1,4 +1,8 @@
name: Triton CI name: Triton CI
workspace:
clean: all
pool: pool:
name: default name: default
@@ -16,27 +20,18 @@ pr:
# Pipeline # Pipeline
steps: steps:
- script: | - script: |
mkdir $(venv) alias python='python3'
python -m virtualenv --python=python3 $(venv)
source $(venv)/bin/activate
python -m pip install --upgrade pip
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 \
-f https://download.pytorch.org/whl/torch_stable.html
cd python cd python
python setup.py install pip3 install -e .
displayName: Setup python environment displayName: Setup python environment
- script: | - script: |
source $(venv)/bin/activate
pip install matplotlib pandas
cd python/bench cd python/bench
python -m run python3 -m run
- publish: python/bench/results - publish: python/bench/results
artifact: Benchmarks artifact: Benchmarks
- script: | - script: |
source $(venv)/bin/activate
pip install pytest
pytest . pytest .
displayName: 'Run Python tests' displayName: 'Run Python tests'

View File

@@ -4,6 +4,9 @@ pr: none
jobs: jobs:
- job: linux - job: linux
workspace:
clean: all
timeoutInMinutes: 180 timeoutInMinutes: 180
pool: default pool: default
@@ -33,4 +36,5 @@ jobs:
inputs: {pathtoPublish: 'wheelhouse'} inputs: {pathtoPublish: 'wheelhouse'}
- bash: | - bash: |
python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD) python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD)
displayName: Upload wheels to PyPI displayName: Upload wheels to PyPI

View File

@@ -22,48 +22,24 @@ endif()
# Compiler flags # Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++17") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
# if(APPLE)
# set(CMAKE_OSX_SYSROOT "/")
# set(CMAKE_OSX_DEPLOYMENT_TARGET "")
# endif()
########## ##########
# LLVM # LLVM
########## ##########
get_cmake_property(_variableNames VARIABLES) find_package(LLVM 11 REQUIRED COMPONENTS "nvptx")
set(__variableNames ${_variableNames}) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
include_directories("${LLVM_INCLUDE_DIRS}")
configure_file(cmake/DownloadLLVM.in ${TRITON_LLVM_BUILD_DIR}/llvm-download/CMakeLists.txt) if(APPLE)
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download" endif()
)
execute_process(COMMAND "${CMAKE_COMMAND}" --build .
WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download"
)
set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "")
set(LLVM_BUILD_RUNTIME "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_RUNTIMES "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_TOOLS "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_UTILS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_BENCHMARKS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_DOCS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_EXAMPLES "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_GO_TESTS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_RUNTIME "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_TESTS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_TOOLS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_UTILS "OFF" CACHE INTERNAL "")
add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src
${TRITON_LLVM_BUILD_DIR}/llvm-build)
get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS)
# remove LLVM-specific variables so we don't pollute GUI
get_cmake_property(_variableNames VARIABLES)
list(REMOVE_ITEM _variableNames ${__variableNames})
list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES})
foreach (_variableName ${_variableNames})
unset(${_variableName} CACHE)
endforeach()
include_directories("${TRITON_LLVM_BUILD_DIR}/llvm-build/include/"
"${TRITON_LLVM_BUILD_DIR}/llvm-src/include/")
# Python module # Python module
if(BUILD_PYTHON_MODULE) if(BUILD_PYTHON_MODULE)
@@ -87,8 +63,15 @@ endif()
# Triton # Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton ${LLVM_LIBRARIES}) target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})
message(STATUS ${LLVM_LDFLAGS})
if(BUILD_PYTHON_MODULE) if(BUILD_PYTHON_MODULE)
target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES}) set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)
set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif() endif()

View File

@@ -1,15 +0,0 @@
cmake_minimum_required(VERSION 3.6)
project(llvm-download NONE)
include(ExternalProject)
ExternalProject_Add(llvm
URL "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz"
SOURCE_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-src"
BINARY_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-build"
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)

View File

@@ -1,3 +1,4 @@
# - Find LLVM headers and libraries. # - Find LLVM headers and libraries.
# This module locates LLVM and adapts the llvm-config output for use with # This module locates LLVM and adapts the llvm-config output for use with
# CMake. # CMake.
@@ -7,14 +8,18 @@
# The following variables are defined: # The following variables are defined:
# LLVM_FOUND - true if LLVM was found # LLVM_FOUND - true if LLVM was found
# LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers. # LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers.
# LLVM_HOST_TARGET - Target triple used to configure LLVM. # LLVM_ENABLE_ASSERTIONS - Whether LLVM was built with enabled assertions (ON/OFF).
# LLVM_INCLUDE_DIRS - Directory containing LLVM include files. # LLVM_INCLUDE_DIRS - Directory containing LLVM include files.
# LLVM_IS_SHARED - Whether LLVM is going to be linked dynamically (ON) or statically (OFF).
# LLVM_LDFLAGS - Linker flags to add when linking against LLVM # LLVM_LDFLAGS - Linker flags to add when linking against LLVM
# (includes -LLLVM_LIBRARY_DIRS). # (includes -LLLVM_LIBRARY_DIRS).
# LLVM_LIBRARIES - Full paths to the library files to link against. # LLVM_LIBRARIES - Full paths to the library files to link against.
# LLVM_LIBRARY_DIRS - Directory containing LLVM libraries. # LLVM_LIBRARY_DIRS - Directory containing LLVM libraries.
# LLVM_NATIVE_ARCH - Backend corresponding to LLVM_HOST_TARGET, e.g.,
# X86 for x86_64 and i686 hosts.
# LLVM_ROOT_DIR - The root directory of the LLVM installation. # LLVM_ROOT_DIR - The root directory of the LLVM installation.
# llvm-config is searched for in ${LLVM_ROOT_DIR}/bin. # llvm-config is searched for in ${LLVM_ROOT_DIR}/bin.
# LLVM_TARGETS_TO_BUILD - List of built LLVM targets.
# LLVM_VERSION_MAJOR - Major version of LLVM. # LLVM_VERSION_MAJOR - Major version of LLVM.
# LLVM_VERSION_MINOR - Minor version of LLVM. # LLVM_VERSION_MINOR - Minor version of LLVM.
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). # LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
@@ -28,16 +33,31 @@
# We also want a user-specified LLVM_ROOT_DIR to take precedence over the # We also want a user-specified LLVM_ROOT_DIR to take precedence over the
# system default locations such as /usr/local/bin. Executing find_program() # system default locations such as /usr/local/bin. Executing find_program()
# multiple times is the approach recommended in the docs. # multiple times is the approach recommended in the docs.
set(llvm_config_names llvm-config-11 llvm-config-11.0 set(llvm_config_names llvm-config-12.0 llvm-config120 llvm-config-12
llvm-config-10 llvm-config-10.0 llvm-config100 llvm-config-11.0 llvm-config110 llvm-config-11
llvm-config-9 llvm-config-9.0 llvm-config90 llvm-config-10.0 llvm-config100 llvm-config-10
llvm-config-8 llvm-config-8.0 llvm-config80 llvm-config-9.0 llvm-config90 llvm-config-9
llvm-config-8.0 llvm-config80 llvm-config-8
llvm-config-7.0 llvm-config70 llvm-config-7
llvm-config-6.0 llvm-config60
llvm-config) llvm-config)
find_program(LLVM_CONFIG find_program(LLVM_CONFIG
NAMES ${llvm_config_names} NAMES ${llvm_config_names}
PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH
DOC "Path to llvm-config tool.") DOC "Path to llvm-config tool.")
find_program(LLVM_CONFIG NAMES ${llvm_config_names}) find_program(LLVM_CONFIG NAMES ${llvm_config_names})
if(APPLE)
# extra fallbacks for MacPorts & Homebrew
find_program(LLVM_CONFIG
NAMES ${llvm_config_names}
PATHS /opt/local/libexec/llvm-11/bin /opt/local/libexec/llvm-10/bin /opt/local/libexec/llvm-9.0/bin
/opt/local/libexec/llvm-8.0/bin /opt/local/libexec/llvm-7.0/bin /opt/local/libexec/llvm-6.0/bin
/opt/local/libexec/llvm/bin
/usr/local/opt/llvm@11/bin /usr/local/opt/llvm@10/bin /usr/local/opt/llvm@9/bin
/usr/local/opt/llvm@8/bin /usr/local/opt/llvm@7/bin /usr/local/opt/llvm@6/bin
/usr/local/opt/llvm/bin
NO_DEFAULT_PATH)
endif()
# Prints a warning/failure message depending on the required/quiet flags. Copied # Prints a warning/failure message depending on the required/quiet flags. Copied
# from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed. # from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed.
@@ -46,7 +66,7 @@ macro(_LLVM_FAIL _msg)
message(FATAL_ERROR "${_msg}") message(FATAL_ERROR "${_msg}")
else() else()
if(NOT LLVM_FIND_QUIETLY) if(NOT LLVM_FIND_QUIETLY)
message(STATUS "${_msg}") message(WARNING "${_msg}")
endif() endif()
endif() endif()
endmacro() endmacro()
@@ -54,7 +74,7 @@ endmacro()
if(NOT LLVM_CONFIG) if(NOT LLVM_CONFIG)
if(NOT LLVM_FIND_QUIETLY) if(NOT LLVM_FIND_QUIETLY)
message(WARNING "Could not find llvm-config (LLVM >= ${LLVM_FIND_VERSION}). Try manually setting LLVM_CONFIG to the llvm-config executable of the installation to use.") _LLVM_FAIL("No LLVM installation (>= ${LLVM_FIND_VERSION}) found. Try manually setting the 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' variables.")
endif() endif()
else() else()
macro(llvm_set var flag) macro(llvm_set var flag)
@@ -63,7 +83,7 @@ else()
endif() endif()
set(result_code) set(result_code)
execute_process( execute_process(
COMMAND ${LLVM_CONFIG} --${flag} COMMAND ${LLVM_CONFIG} --link-static --${flag}
RESULT_VARIABLE result_code RESULT_VARIABLE result_code
OUTPUT_VARIABLE LLVM_${var} OUTPUT_VARIABLE LLVM_${var}
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
@@ -77,13 +97,13 @@ else()
endif() endif()
endif() endif()
endmacro() endmacro()
macro(llvm_set_libs var flag) macro(llvm_set_libs var flag components)
if(LLVM_FIND_QUIETLY) if(LLVM_FIND_QUIETLY)
set(_quiet_arg ERROR_QUIET) set(_quiet_arg ERROR_QUIET)
endif() endif()
set(result_code) set(result_code)
execute_process( execute_process(
COMMAND ${LLVM_CONFIG} --${flag} ${LLVM_FIND_COMPONENTS} COMMAND ${LLVM_CONFIG} --link-static --${flag} ${components}
RESULT_VARIABLE result_code RESULT_VARIABLE result_code
OUTPUT_VARIABLE tmplibs OUTPUT_VARIABLE tmplibs
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
@@ -91,7 +111,7 @@ else()
) )
if(result_code) if(result_code)
_LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'") _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'")
else() else()
file(TO_CMAKE_PATH "${tmplibs}" tmplibs) file(TO_CMAKE_PATH "${tmplibs}" tmplibs)
string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs}) string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs})
endif() endif()
@@ -99,31 +119,29 @@ else()
llvm_set(VERSION_STRING version) llvm_set(VERSION_STRING version)
llvm_set(CXXFLAGS cxxflags) llvm_set(CXXFLAGS cxxflags)
llvm_set(HOST_TARGET host-target)
llvm_set(INCLUDE_DIRS includedir true) llvm_set(INCLUDE_DIRS includedir true)
llvm_set(ROOT_DIR prefix true) llvm_set(ROOT_DIR prefix true)
llvm_set(ENABLE_ASSERTIONS assertion-mode) llvm_set(ENABLE_ASSERTIONS assertion-mode)
# The LLVM version string _may_ contain a git/svn suffix, so cut that off # The LLVM version string _may_ contain a git/svn suffix, so match only the x.y.z part
string(SUBSTRING "${LLVM_VERSION_STRING}" 0 5 LLVM_VERSION_BASE_STRING) string(REGEX MATCH "^[0-9]+[.][0-9]+[.][0-9]+" LLVM_VERSION_BASE_STRING "${LLVM_VERSION_STRING}")
# Versions below 4.0 do not support components debuginfomsf and demangle llvm_set(SHARED_MODE shared-mode)
if(${LLVM_VERSION_STRING} MATCHES "^3\\..*") if(LLVM_SHARED_MODE STREQUAL "shared")
list(REMOVE_ITEM LLVM_FIND_COMPONENTS "debuginfomsf" index) set(LLVM_IS_SHARED ON)
list(REMOVE_ITEM LLVM_FIND_COMPONENTS "demangle" index) else()
endif() set(LLVM_IS_SHARED OFF)
# Versions below 8.0 not supported
if(${LLVM_VERSION_STRING} MATCHES "^[3-7]\\..*")
message(FATAL_ERROR "LLVM version below 8.0 not supported")
endif() endif()
llvm_set(LDFLAGS ldflags) llvm_set(LDFLAGS ldflags)
# In LLVM 3.5+, the system library dependencies (e.g. "-lz") are accessed
# using the separate "--system-libs" flag.
llvm_set(SYSTEM_LIBS system-libs) llvm_set(SYSTEM_LIBS system-libs)
string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}") string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}")
if(APPLE) # unclear why/how this happens
string(REPLACE "-llibxml2.tbd" "-lxml2" LLVM_LDFLAGS ${LLVM_LDFLAGS})
endif()
llvm_set(LIBRARY_DIRS libdir true) llvm_set(LIBRARY_DIRS libdir true)
llvm_set_libs(LIBRARIES libs) llvm_set_libs(LIBRARIES libfiles "${LLVM_FIND_COMPONENTS}")
# LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0 # LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0
# but code for it is not in shared library # but code for it is not in shared library
if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen") if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen")
@@ -132,37 +150,50 @@ else()
endif() endif()
endif() endif()
# Versions below 4.0 do not support llvm-config --cmakedir llvm_set(CMAKEDIR cmakedir)
if(${LLVM_VERSION_STRING} MATCHES "^3\\..*")
set(LLVM_CMAKEDIR ${LLVM_LIBRARY_DIRS}/cmake/llvm)
else()
llvm_set(CMAKEDIR cmakedir)
endif()
llvm_set(TARGETS_TO_BUILD targets-built) llvm_set(TARGETS_TO_BUILD targets-built)
string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD}) string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD})
# Parse LLVM_NATIVE_ARCH manually from LLVMConfig.cmake; including it leads to issues like
# https://github.com/ldc-developers/ldc/issues/3079.
file(STRINGS "${LLVM_CMAKEDIR}/LLVMConfig.cmake" LLVM_NATIVE_ARCH LIMIT_COUNT 1 REGEX "^set\\(LLVM_NATIVE_ARCH (.+)\\)$")
string(REGEX MATCH "set\\(LLVM_NATIVE_ARCH (.+)\\)" LLVM_NATIVE_ARCH "${LLVM_NATIVE_ARCH}")
set(LLVM_NATIVE_ARCH ${CMAKE_MATCH_1})
message(STATUS "LLVM_NATIVE_ARCH: ${LLVM_NATIVE_ARCH}")
# On CMake builds of LLVM, the output of llvm-config --cxxflags does not
# include -fno-rtti, leading to linker errors. Be sure to add it.
if(NOT MSVC AND (CMAKE_COMPILER_IS_GNUCXX OR (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")))
if(NOT ${LLVM_CXXFLAGS} MATCHES "-fno-rtti")
set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fno-rtti")
endif()
endif()
# Remove some clang-specific flags for gcc.
if(CMAKE_COMPILER_IS_GNUCXX)
string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
# this requires more recent gcc versions (not supported by 4.9)
string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
# Remove gcc-specific flags for clang.
if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" )
string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" )
if (${LLVM_VERSION_STRING} VERSION_LESS ${LLVM_FIND_VERSION})
_LLVM_FAIL("Unsupported LLVM version ${LLVM_VERSION_STRING} found (${LLVM_CONFIG}). At least version ${LLVM_FIND_VERSION} is required. You can also set variables 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' to use a different LLVM installation.")
endif()
endif() endif()
# Remove some clang-specific flags for gcc.
if(CMAKE_COMPILER_IS_GNUCXX)
string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
# Remove gcc-specific flags for clang.
if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" )
string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" )
# Use the default CMake facilities for handling QUIET/REQUIRED. # Use the default CMake facilities for handling QUIET/REQUIRED.
include(FindPackageHandleStandardArgs) include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR LLVM_HOST_TARGET REQUIRED_VARS LLVM_ROOT_DIR
VERSION_VAR LLVM_VERSION_STRING) VERSION_VAR LLVM_VERSION_STRING)

View File

@@ -59,6 +59,13 @@ std::string exec(const char* cmd) {
return result; return result;
} }
void LLVMInitializeNVPTXTargetInfo();
void LLVMInitializeNVPTXTarget();
void LLVMInitializeNVPTXTargetMC();
void LLVMInitializeNVPTXAsmPrinter();
void LLVMInitializeNVPTXAsmParser();
namespace triton namespace triton
{ {
namespace driver namespace driver
@@ -68,14 +75,14 @@ namespace driver
// Base // // Base //
/* ------------------------ */ /* ------------------------ */
void module::init_llvm() { void module::init_llvm() {
static bool init = false; static bool init = false;
if(!init){ if(!init){
llvm::InitializeAllTargetInfos(); LLVMInitializeNVPTXTargetInfo();
llvm::InitializeAllTargets(); LLVMInitializeNVPTXTarget();
llvm::InitializeAllTargetMCs(); LLVMInitializeNVPTXTargetMC();
llvm::InitializeAllAsmParsers(); LLVMInitializeNVPTXAsmPrinter();
llvm::InitializeAllAsmPrinters();
init = true; init = true;
} }
} }
@@ -111,80 +118,81 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
/* ------------------------ */ /* ------------------------ */
host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_t(), true) { host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_t(), true) {
init_llvm(); throw std::runtime_error("CPU unsupported");
// create kernel wrapper // init_llvm();
llvm::LLVMContext &ctx = src->getContext(); // // create kernel wrapper
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx); // llvm::LLVMContext &ctx = src->getContext();
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo(); // llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx); // llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty}; // llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false); // std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src); // llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* fn = &*src->getFunctionList().begin(); // llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
llvm::FunctionType *fn_ty = fn->getFunctionType(); // llvm::Function* fn = &*src->getFunctionList().begin();
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams()); // llvm::FunctionType *fn_ty = fn->getFunctionType();
std::vector<llvm::Value*> ptrs(fn_args.size() - 3); // std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main); // std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
llvm::IRBuilder<> ir_builder(ctx); // llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
ir_builder.SetInsertPoint(entry); // llvm::IRBuilder<> ir_builder(ctx);
auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; }; // ir_builder.SetInsertPoint(entry);
llvm::Value* base = main->arg_begin(); // auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType()); // llvm::Value* base = main->arg_begin();
// llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
size_t offset = 0; // size_t offset = 0;
for(unsigned i = 0; i < ptrs.size(); i++){ // for(unsigned i = 0; i < ptrs.size(); i++){
ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset)); // ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
size_t nbytes = get_size(fn_ty->getParamType(i)); // size_t nbytes = get_size(fn_ty->getParamType(i));
offset += nbytes; // offset += nbytes;
if(i < ptrs.size() - 1){ // if(i < ptrs.size() - 1){
size_t np1bytes = get_size(fn_ty->getParamType(i+1)); // size_t np1bytes = get_size(fn_ty->getParamType(i+1));
offset = (offset + np1bytes - 1) / np1bytes * np1bytes; // offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
} // }
} // }
for(unsigned i = 0; i < ptrs.size(); i++) // for(unsigned i = 0; i < ptrs.size(); i++)
ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo()); // ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
for(unsigned i = 0; i < ptrs.size(); i++) // for(unsigned i = 0; i < ptrs.size(); i++)
fn_args[i] = ir_builder.CreateLoad(ptrs[i]); // fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
fn_args[fn_args.size() - 3] = main->arg_begin() + 1; // fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
fn_args[fn_args.size() - 2] = main->arg_begin() + 2; // fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
fn_args[fn_args.size() - 1] = main->arg_begin() + 3; // fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
ir_builder.CreateCall(fn, fn_args); // ir_builder.CreateCall(fn, fn_args);
ir_builder.CreateRetVoid(); // ir_builder.CreateRetVoid();
// llvm::legacy::PassManager pm; //// llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs())); //// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass()); //// pm.add(llvm::createVerifierPass());
// pm.run(*src); //// pm.run(*src);
// create execution engine //// create execution engine
for(llvm::Function& fn: src->functions()) // for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName().str()] = &fn; // hst_->functions[fn.getName().str()] = &fn;
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); //// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget(); //// auto DL = JTMB.getDefaultDataLayoutForTarget();
// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB)); //// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
// hst_->ES = new llvm::orc::ExecutionSession(); //// hst_->ES = new llvm::orc::ExecutionSession();
// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); }); //// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC); //// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
// hst_->DL = new llvm::DataLayout(std::move(*DL)); //// hst_->DL = new llvm::DataLayout(std::move(*DL));
// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL); //// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext())); //// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
// hst_->MainJD = &hst_->ES->createJITDylib("<main>"); //// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( //// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
// hst_->DL->getGlobalPrefix()))); //// hst_->DL->getGlobalPrefix())));
// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx))); //// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress()); //// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
llvm::EngineBuilder builder(std::move(src)); // llvm::EngineBuilder builder(std::move(src));
builder.setErrorStr(&hst_->error); // builder.setErrorStr(&hst_->error);
builder.setMCJITMemoryManager(std::make_unique<llvm::SectionMemoryManager>()); // builder.setMCJITMemoryManager(std::make_unique<llvm::SectionMemoryManager>());
builder.setOptLevel(llvm::CodeGenOpt::Aggressive); // builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
builder.setEngineKind(llvm::EngineKind::JIT); // builder.setEngineKind(llvm::EngineKind::JIT);
hst_->engine = builder.create(); // hst_->engine = builder.create();
hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main")); // hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
} }
std::unique_ptr<buffer> host_module::symbol(const char *name) const { std::unique_ptr<buffer> host_module::symbol(const char *name) const {
@@ -211,7 +219,7 @@ static std::map<int, int> vptx = {
{11010, 71}, {11010, 71},
{11020, 72}, {11020, 72},
{11030, 73}, {11030, 73},
{11040, 74} {11040, 73}
}; };
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {

View File

@@ -66,7 +66,7 @@ class CMakeBuild(build_ext):
"-DBUILD_TUTORIALS=OFF", "-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON", "-DBUILD_PYTHON_MODULE=ON",
#'-DPYTHON_EXECUTABLE=' + sys.executable, #'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
] ]