diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 70b0f5d7c..82f536582 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -9,7 +9,7 @@ on: jobs: Integration-Tests: - + runs-on: self-hosted steps: @@ -31,7 +31,7 @@ jobs: run: | pip install autopep8 autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 ) - + - name: Check cpp style run: | sudo apt-get install -y clang-format @@ -62,3 +62,9 @@ jobs: run: | cd python/tests pytest + + - name: Run CXX unittests + run: | + cd python/ + cd "build/$(ls build)" + ctest diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b431b37d..2fb182135 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,3 +220,5 @@ if(BUILD_PYTHON_MODULE AND NOT WIN32) endif() add_subdirectory(test) + +add_subdirectory(unittest) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index f78e9c8e9..8f5618d43 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -20,9 +20,13 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include + +#if defined __has_include #if __has_include() #include #endif +#endif + #include "triton/driver/dispatch.h" #include "triton/driver/error.h" #include "triton/driver/llvm.h" @@ -87,9 +91,11 @@ static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target) { size_t start_replace = str.find(begin); - size_t end_replace = str.find(end, start_replace); if (start_replace == std::string::npos) return false; + size_t end_replace = str.find(end, start_replace); + if (end_replace == std::string::npos) + return false; str.replace(start_replace, end_replace + 1 - start_replace, target); return true; } @@ -104,7 +110,7 @@ std::string path_to_ptxas(int &version) { ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); // see what path for ptxas are valid std::vector working_ptxas; - for (std::string prefix : ptxas_prefixes) { + for (const std::string &prefix : ptxas_prefixes) { std::string ptxas = prefix + "ptxas"; bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; if (works) { @@ -124,19 +130,21 @@ std::string path_to_ptxas(int &version) { bool found = false; // currently choosing the first ptxas. Other logics can be implemented in // future - for (std::string ret : rets) { - if (std::regex_search(ret, match, version_regex)) { + size_t i = 0; + while (i < rets.size()) { + if (std::regex_search(rets[i], match, version_regex)) { int major = std::stoi(match[1]); int minor = std::stoi(match[2]); version = major * 1000 + minor * 10; found = true; break; } + ++i; } if (not found) { throw std::runtime_error("Error in parsing version"); } - return ptxas; + return working_ptxas[i]; } int vptx(int version) { diff --git a/unittest/Analysis/CMakeLists.txt b/unittest/Analysis/CMakeLists.txt new file mode 100644 index 000000000..4db4a37af --- /dev/null +++ b/unittest/Analysis/CMakeLists.txt @@ -0,0 +1,5 @@ +add_triton_ut( + NAME TritonAnalysisTests + SRCS UtilityTest.cpp + LIBS TritonAnalysis +) diff --git a/unittest/Analysis/UtilityTest.cpp b/unittest/Analysis/UtilityTest.cpp new file mode 100644 index 000000000..69a7119a4 --- /dev/null +++ b/unittest/Analysis/UtilityTest.cpp @@ -0,0 +1,14 @@ +//===- UtilityTest.cpp - Tests for +// Utility----------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "triton/Analysis/Utility.h" +#include +#include + +namespace mlir { + +TEST(UtilityTest, DummyTest) { EXPECT_EQ(true, true); } + +} // namespace mlir diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt new file mode 100644 index 000000000..af705f9bf --- /dev/null +++ b/unittest/CMakeLists.txt @@ -0,0 +1,27 @@ + +include (${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake) + +include(GoogleTest) +enable_testing() + +function(add_triton_ut) + set(options) + set(oneValueArgs NAME) + set(multiValueArgs SRCS LIBS) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + add_test(NAME ${__NAME} + COMMAND ${__NAME}) + add_executable( + ${__NAME} + ${__SRCS}) + target_link_libraries( + ${__NAME} + GTest::gtest_main + gmock + ${__LIBS}) + + gtest_discover_tests(${__NAME}) +endfunction() + +add_subdirectory(Analysis) +add_subdirectory(Conversion) diff --git a/unittest/Conversion/CMakeLists.txt b/unittest/Conversion/CMakeLists.txt new file mode 100644 index 000000000..b543b6c62 --- /dev/null +++ b/unittest/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGPUToLLVM) diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..24d071dd2 --- /dev/null +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,5 @@ +add_triton_ut( + NAME TritonGPUToLLVMTests + SRCS TritonGPUToLLVMTests.cpp + LIBS TritonGPUToLLVM +) diff --git a/unittest/Conversion/TritonGPUToLLVM/TritonGPUToLLVMTests.cpp b/unittest/Conversion/TritonGPUToLLVM/TritonGPUToLLVMTests.cpp new file mode 100644 index 000000000..a6468e6e3 --- /dev/null +++ b/unittest/Conversion/TritonGPUToLLVM/TritonGPUToLLVMTests.cpp @@ -0,0 +1,14 @@ +//===- TritonGPUToLLVMTests.cpp - Tests for +// TritonGPUToLLVM----------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" +#include +#include + +namespace mlir { + +TEST(PtxAsmFormatTest, BasicTest) { EXPECT_EQ(true, true); } + +} // namespace mlir diff --git a/unittest/googletest.cmake b/unittest/googletest.cmake new file mode 100644 index 000000000..41d3d4fa4 --- /dev/null +++ b/unittest/googletest.cmake @@ -0,0 +1,23 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.12.1 + ) + +FetchContent_GetProperties(googletest) + +if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + if (MSVC) + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + endif() + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) +endif()