From 59281f579454de74d73d45c4bd84044c9d15c35d Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 27 Aug 2019 20:33:38 -0700 Subject: [PATCH] [structure] better directory structure for tests --- CMakeLists.txt | 10 +- cmake/FindTensorFlow.cmake | 20 --- cmake/FindTorch.cmake | 14 -- examples/CMakeLists.txt | 1 - examples/cpp/CMakeLists.txt | 6 - examples/cpp/cuda.h | 160 -------------------- include/triton/driver/dispatch.h | 2 +- tests/CMakeLists.txt | 3 + tests/bench/CMakeLists.txt | 6 + tests/bench/dot.cc | 98 ++++++++++++ tests/common/cuda/cublas.h | 221 ++++++++++++++++++++++++++++ tests/common/cuda/forward.h | 105 +++++++++++++ tests/common/src/dot.h | 77 ++++++++++ tests/unit/CMakeLists.txt | 6 + {examples/cpp => tests/unit}/dot.cc | 103 +++---------- 15 files changed, 539 insertions(+), 293 deletions(-) delete mode 100644 cmake/FindTensorFlow.cmake delete mode 100644 cmake/FindTorch.cmake delete mode 100644 examples/CMakeLists.txt delete mode 100644 examples/cpp/CMakeLists.txt delete mode 100644 examples/cpp/cuda.h create mode 100644 tests/CMakeLists.txt create mode 100644 tests/bench/CMakeLists.txt create mode 100644 tests/bench/dot.cc create mode 100644 tests/common/cuda/cublas.h create mode 100644 tests/common/cuda/forward.h create mode 100644 tests/common/src/dot.h create mode 100644 tests/unit/CMakeLists.txt rename {examples/cpp => tests/unit}/dot.cc (67%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 637718fa6..9e05aca5d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ include(CTest) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Options -option(BUILD_EXAMPLES "Build C++ Triton examples" ON) +option(BUILD_TESTS "Build C++ Triton tests" ON) option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) # LLVM @@ -23,10 +23,10 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11") -# Examples -if(BUILD_EXAMPLES) - message(STATUS "Adding C++ examples") - add_subdirectory(examples) +# Tests +if(BUILD_TESTS) + message(STATUS "Adding C++ tests") + add_subdirectory(tests) endif() # Python module diff --git a/cmake/FindTensorFlow.cmake b/cmake/FindTensorFlow.cmake deleted file mode 100644 index 405febbeb..000000000 --- a/cmake/FindTensorFlow.cmake +++ /dev/null @@ -1,20 +0,0 @@ -include(FindPackageHandleStandardArgs) -unset(TENSORFLOW_FOUND) - -execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))" - OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) -execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())" - OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) -execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)" - OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) - -find_package_handle_standard_args(TensorFlow DEFAULT_MSG TF_INC TF_LIB) - -# set external variables for usage in CMakeLists.txt -if(TensorFlow_FOUND) - set(TensorFlow_LIBRARIES ${TF_LIB}) - set(TensorFlow_INCLUDE_DIRS ${TF_INC}) - set(TensorFlow_ABI ${TF_ABI}) -endif() - -mark_as_advanced(TF_INC TF_LIB TF_ABI) diff --git a/cmake/FindTorch.cmake b/cmake/FindTorch.cmake deleted file mode 100644 index 79a814d03..000000000 --- a/cmake/FindTorch.cmake +++ /dev/null @@ -1,14 +0,0 @@ -include(FindPackageHandleStandardArgs) -execute_process(COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__))" - OUTPUT_VARIABLE TORCH_INSTALL_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) - -find_package_handle_standard_args(TORCH DEFAULT_MSG TORCH_INSTALL_PREFIX) -if(TORCH_INSTALL_PREFIX) - set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/ - ${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include - ${TORCH_INSTALL_PREFIX}/include/ - ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include/) - set(TORCH_LIBRARY_DIRS ${TORCH_INSTALL_PREFIX}/lib/) -endif() - -mark_as_advanced(TORCH_INCLUDE_DIRS TORCH_LIBRARY_DIRS) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt deleted file mode 100644 index 2322a85f7..000000000 --- a/examples/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(cpp) diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt deleted file mode 100644 index cea728c8e..000000000 --- a/examples/cpp/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -foreach(PROG dot) - add_executable(${PROG} ${PROG}.cc) - set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG}) - include_directories(/usr/local/cuda/include/) - target_link_libraries(${PROG} triton cublas) -endforeach(PROG) diff --git a/examples/cpp/cuda.h b/examples/cpp/cuda.h deleted file mode 100644 index fef17dc55..000000000 --- a/examples/cpp/cuda.h +++ /dev/null @@ -1,160 +0,0 @@ -/* Copyright 2015-2017 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, -* subject to the following conditions: -* -* The above copyright notice and this permission notice shall be -* included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ - -#include -#include -#include -#include "cublas_v2.h" -#include "triton/driver/buffer.h" -#include "triton/driver/stream.h" -#include "triton/driver/context.h" -#include "triton/tools/bench.hpp" - -enum cublasStrategy_t{ - CUBLAS_PREFER_FASTEST, - CUBLAS_HEURISTICS -}; - -enum DType{ - HALF_TYPE, - FLOAT_TYPE, - DOUBLE_TYPE, -}; - -inline size_t size_of(DType dtype){ - switch (dtype) { - case HALF_TYPE: return 2; - case FLOAT_TYPE: return 4; - case DOUBLE_TYPE: return 8; - default: throw; - } -} - -inline std::vector gather_all_algos() { - std::vector result; - // non-tensor ops - for(int i = -1; i < 24; i++) - result.push_back((cublasGemmAlgo_t)i); - // tensor ops - for(int i = 99; i < 116; i++) - result.push_back((cublasGemmAlgo_t)i); - return result; -} - -static const std::vector algorithms = gather_all_algos(); - -static const std::map cu_dtype = { - {HALF_TYPE, CUDA_R_16F}, - {FLOAT_TYPE, CUDA_R_32F}, - {DOUBLE_TYPE, CUDA_R_64F} -}; - -static const std::map cu_op = { - {false, CUBLAS_OP_N}, - {true, CUBLAS_OP_T} -}; - -inline cublasGemmAlgo_t cublasGemmFastest( - triton::driver::stream* stream, - cublasHandle_t handle, cudaDataType cudt, - cublasOperation_t AT, cublasOperation_t BT, - int32_t M, int32_t N, int32_t K, - void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, - void* beta, CUdeviceptr C, int32_t ldc) { - - // cache to avoid re-benchmarking - typedef std::tuple key_t; - static std::map cache; - key_t key(cudt, AT, BT, M, N, K); - // benchmark algorithms if necessary - if(cache.find(key) == cache.end()){ - std::vector times; - for(cublasGemmAlgo_t a: algorithms) { - cublasStatus_t status; - double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT, - M, N, K, - alpha, (const void*)A, cudt, lda, - (const void*)B, cudt, ldb, - beta, (void*)C, cudt, ldc, cudt, - a); }, stream); - if(status != CUBLAS_STATUS_SUCCESS) - nanosec = INFINITY; - } - size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); - assert(times[argmin] != INFINITY); - cache.insert({key, algorithms[argmin]}); - } - - // return best algorithm - return cache.at(key); -} - -/* Wrapper for cublasGemmEx */ -inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K, - void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, - void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); - if(status != CUBLAS_STATUS_SUCCESS){ - std::cout << status; - exit(EXIT_FAILURE); - } -} - - -/* Get cuBLAS handle */ -inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { - static std::map cache; - CUstream key = *stream->cu(); - - // create handle if necessary - if(cache.find(key) == cache.end()) { - cublasHandle_t handle; - if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) - throw std::runtime_error("Error: could not create cuBLAS handle"); - cublasSetStream_v2(handle, key); - cache.insert({key, handle}); - } - - // return handle for the stream - return cache.at(key); -} - -/* Simplified API for default GEMM */ -inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - void* alpha, triton::driver::buffer* A, int32_t lda, - triton::driver::buffer* B, int32_t ldb, - void* beta, triton::driver::buffer* C, int32_t ldc, - cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { - triton::driver::cu_context::context_switcher scope(*stream->context()); - static cublasHandle_t handle = cublasGetHandle(stream); - if(dtype == HALF_TYPE) - cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); - cublasStatus_t status; - if(fastest) - *fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); - else - status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo); -} diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 7f6fdf7e0..ed717a7fb 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -34,7 +34,7 @@ void check(cl_int err); class dispatch { -private: +protected: template struct return_type; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 000000000..8c80ee070 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,3 @@ +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/common") +add_subdirectory(bench) +add_subdirectory(unit) diff --git a/tests/bench/CMakeLists.txt b/tests/bench/CMakeLists.txt new file mode 100644 index 000000000..1f3cc3341 --- /dev/null +++ b/tests/bench/CMakeLists.txt @@ -0,0 +1,6 @@ +foreach(PROG dot) + set(TARGET bench_${PROG}) + add_executable(${TARGET} ${PROG}.cc) + set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) + target_link_libraries(${TARGET} triton dl) +endforeach(PROG) diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc new file mode 100644 index 000000000..63e5e877d --- /dev/null +++ b/tests/bench/dot.cc @@ -0,0 +1,98 @@ +#include +#include +#include +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" +#include "triton/tools/bench.hpp" +#include "triton/external/half.hpp" +#include "triton/runtime/function.h" +#include "src/dot.h" +#include "cuda/cublas.h" + + +struct perf_t { + double triton; + double cublas; +}; + +namespace drv = triton::driver; +namespace rt = triton::runtime; + +inline size_t ceil(size_t x, size_t y) { + return (x + y - 1) / y; +}; + + +std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ + typedef half_float::half NumericT; + std::string ty = "half"; + size_t dt_nbytes = sizeof(NumericT); + drv::context* context = stream->context(); + // leading dimensions + int32_t lda = AT ? K : M; + int32_t ldb = BT ? N : K; + int32_t ldc = M; + // create inputs + auto dc = std::unique_ptr(drv::buffer::create(context, M*N*dt_nbytes)); + auto da = std::unique_ptr(drv::buffer::create(context, M*K*dt_nbytes)); + auto db = std::unique_ptr(drv::buffer::create(context, K*N*dt_nbytes)); + // create options + rt::function::options_space_t opt; + opt.defines.push_back({"TYPE", {ty}}); + if(AT) + opt.defines.push_back({"AT", {""}}); + if(BT) + opt.defines.push_back({"BT", {""}}); + opt.defines.push_back({"TM", {"16", "32", "64", "128"}}); + opt.defines.push_back({"TN", {"16", "32", "64", "128"}}); + opt.defines.push_back({"TK", {"32"}}); + opt.num_warps = {1, 2, 4, 8}; + // create grid + auto grid = [&](const rt::function::options_t& x) { + return rt::grid_t{ceil(M, x.D("TM")), + ceil(N, x.D("TN"))}; + }; + // create function + rt::function function(src::dot, opt); + // benchmark available libraries + std::vector result; + auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; + // cublas + if(cublas::cublasinit()){ + NumericT alpha(static_cast(1)); + NumericT beta(static_cast(0)); + cublasGemmAlgo_t fastest; + cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); + double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, + &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream); + result.push_back(tflops(cublas_ms)); + } + // triton + double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream); + result.push_back(tflops(triton_ms)); + // done + return result; +} + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + triton::driver::stream* stream = triton::driver::stream::create(context); + // shapes to benchmark + typedef std::tuple config_t; + std::vector configs = { + config_t{false, true, 512, 512, 512}, + config_t{false, true, 2048, 2048, 2048}, + config_t{false, true, 8192, 8192, 8192} + }; + // does the work + bool AT, BT; + int32_t M, N, K; + for(const auto& c: configs){ + std::tie(AT, BT, M, N, K) = c; + std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush; + for(auto perf: do_bench(stream, AT, BT, M, N, K)) + std::cout << ", " << perf << std::flush; + std::cout << std::endl; + } +} diff --git a/tests/common/cuda/cublas.h b/tests/common/cuda/cublas.h new file mode 100644 index 000000000..db1f2a360 --- /dev/null +++ b/tests/common/cuda/cublas.h @@ -0,0 +1,221 @@ +/* Copyright 2019 Philippe Tillet +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +#include +#include +#include +#include "forward.h" +#include "triton/driver/buffer.h" +#include "triton/driver/stream.h" +#include "triton/driver/context.h" +#include "triton/driver/error.h" +#include "triton/tools/bench.hpp" + + +class cublas { +private: + template + struct return_type; + + template + struct return_type + { typedef R type; }; + + typedef bool (*f_init_t)(); + + template + static typename return_type::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args) + { + initializer(); + if(cache == nullptr){ + cache = dlsym(lib_h, name); + if(cache == 0) + throw std::runtime_error("dlsym unable to load function"); + } + FunPtrT fptr; + *reinterpret_cast(&fptr) = cache; + typename return_type::type res = (*fptr)(args...); + triton::driver::check(res); + return res; + } + +public: + static bool cublasinit(); + static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m); + static cublasStatus_t cublasCreate_v2(cublasHandle_t* h); + static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId); + static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId); + static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const void *alpha, const void *A, cudaDataType Atype, int lda, + const void *B, cudaDataType Btype, int ldb, const void *beta, + void *C, cudaDataType Ctype, int ldc, + cudaDataType computeType, cublasGemmAlgo_t algo); + +private: + static void* so_; + static void* cublasGetStream_v2_; + static void* cublasSetStream_v2_; + static void* cublasCreate_v2_; + static void* cublasGemmEx_; + static void* cublasSetMathMode_; +}; + +void* cublas::so_; +void* cublas::cublasGetStream_v2_; +void* cublas::cublasSetStream_v2_; +void* cublas::cublasCreate_v2_; +void* cublas::cublasGemmEx_; +void* cublas::cublasSetMathMode_; + + +bool cublas::cublasinit() { + if(so_==nullptr) + so_ = dlopen("libcublas.so", RTLD_LAZY); + return so_ != nullptr; +} + +cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a) +{ return f_impl(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); } +cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a) +{ return f_impl(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); } +cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, + const void *alpha, const void *A, cudaDataType Atype, int lda, + const void *B, cudaDataType Btype, int ldb, const void *beta, + void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) { + return f_impl(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); +} +cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) { + return f_impl(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h); +} +cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) { + return f_impl(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m); +} + + + +inline cublasGemmAlgo_t cublasGemmFastest( + triton::driver::stream* stream, + cublasHandle_t handle, cudaDataType cudt, + cublasOperation_t AT, cublasOperation_t BT, + int32_t M, int32_t N, int32_t K, + void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, + void* beta, CUdeviceptr C, int32_t ldc) { + + // initialize list of cublas algorithms + static std::vector algorithms; + if(algorithms.empty()) { + // non-tensor ops + for(int i = -1; i < 24; i++) + algorithms.push_back((cublasGemmAlgo_t)i); + // tensor ops + for(int i = 99; i < 116; i++) + algorithms.push_back((cublasGemmAlgo_t)i); + } + + // cache to avoid re-benchmarking + typedef std::tuple key_t; + static std::map cache; + key_t key(cudt, AT, BT, M, N, K); + // benchmark algorithms if necessary + if(cache.find(key) == cache.end()){ + std::vector times; + for(cublasGemmAlgo_t a: algorithms) { + cublasStatus_t status; + double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT, + M, N, K, + alpha, (const void*)A, cudt, lda, + (const void*)B, cudt, ldb, + beta, (void*)C, cudt, ldc, cudt, + a); }, stream); + if(status != CUBLAS_STATUS_SUCCESS) + nanosec = INFINITY; + } + size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); + assert(times[argmin] != INFINITY); + cache.insert({key, algorithms[argmin]}); + } + + // return best algorithm + return cache.at(key); +} + + + + +/* Get cuBLAS handle */ +inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { + static std::map cache; + CUstream key = *stream->cu(); + + // create handle if necessary + if(cache.find(key) == cache.end()) { + cublasHandle_t handle; + if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) + throw std::runtime_error("Error: could not create cuBLAS handle"); + cublas::cublasSetStream_v2(handle, key); + cache.insert({key, handle}); + } + + // return handle for the stream + return cache.at(key); +} + + + +/* Simplified API for default GEMM */ +inline void cublasGemm(cublasDataType_t dtype, + triton::driver::stream* stream, + bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + void* alpha, triton::driver::buffer* A, int32_t lda, + triton::driver::buffer* B, int32_t ldb, + void* beta, triton::driver::buffer* C, int32_t ldc, + cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { + + // switch triton context + triton::driver::cu_context::context_switcher scope(*stream->context()); + // get handle + static cublasHandle_t handle = cublasGetHandle(stream); + // set math mode + if(dtype == CUDA_R_16F) + cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); + // cuda types + static const std::map cu_op = { + {false, CUBLAS_OP_N}, + {true, CUBLAS_OP_T} + }; + cublasOperation_t opa = cu_op.at(AT); + cublasOperation_t opb = cu_op.at(BT); + // benchmark fastest + if(fastest) + *fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); + else { + // execute supplied algo + cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K, + alpha, (const void*)*A->cu(), dtype, lda, + (const void*)*B->cu(), dtype, ldb, + beta, (void*)*C->cu(), dtype, ldc, dtype, algo); + } +} diff --git a/tests/common/cuda/forward.h b/tests/common/cuda/forward.h new file mode 100644 index 000000000..1c12c4247 --- /dev/null +++ b/tests/common/cuda/forward.h @@ -0,0 +1,105 @@ +#ifndef _COMMON_CUDA_FORWARDS_H_ +#define _COMMON_CUDA_FORwARDS_H_ + +struct cublasContext; +typedef struct cublasContext *cublasHandle_t; +struct CUstream_st; +typedef struct CUstream_st *cudaStream_t; + +/* CUBLAS status type returns */ +typedef enum{ + CUBLAS_STATUS_SUCCESS =0, + CUBLAS_STATUS_NOT_INITIALIZED =1, + CUBLAS_STATUS_ALLOC_FAILED =3, + CUBLAS_STATUS_INVALID_VALUE =7, + CUBLAS_STATUS_ARCH_MISMATCH =8, + CUBLAS_STATUS_MAPPING_ERROR =11, + CUBLAS_STATUS_EXECUTION_FAILED=13, + CUBLAS_STATUS_INTERNAL_ERROR =14, + CUBLAS_STATUS_NOT_SUPPORTED =15, + CUBLAS_STATUS_LICENSE_ERROR =16 +} cublasStatus_t; + +/*For different GEMM algorithm */ +typedef enum { + CUBLAS_GEMM_DFALT = -1, + CUBLAS_GEMM_DEFAULT = -1, + CUBLAS_GEMM_ALGO0 = 0, + CUBLAS_GEMM_ALGO1 = 1, + CUBLAS_GEMM_ALGO2 = 2, + CUBLAS_GEMM_ALGO3 = 3, + CUBLAS_GEMM_ALGO4 = 4, + CUBLAS_GEMM_ALGO5 = 5, + CUBLAS_GEMM_ALGO6 = 6, + CUBLAS_GEMM_ALGO7 = 7, + CUBLAS_GEMM_ALGO8 = 8, + CUBLAS_GEMM_ALGO9 = 9, + CUBLAS_GEMM_ALGO10 = 10, + CUBLAS_GEMM_ALGO11 = 11, + CUBLAS_GEMM_ALGO12 = 12, + CUBLAS_GEMM_ALGO13 = 13, + CUBLAS_GEMM_ALGO14 = 14, + CUBLAS_GEMM_ALGO15 = 15, + CUBLAS_GEMM_ALGO16 = 16, + CUBLAS_GEMM_ALGO17 = 17, + CUBLAS_GEMM_ALGO18 = 18, //sliced 32x32 + CUBLAS_GEMM_ALGO19 = 19, //sliced 64x32 + CUBLAS_GEMM_ALGO20 = 20, //sliced 128x32 + CUBLAS_GEMM_ALGO21 = 21, //sliced 32x32 -splitK + CUBLAS_GEMM_ALGO22 = 22, //sliced 64x32 -splitK + CUBLAS_GEMM_ALGO23 = 23, //sliced 128x32 -splitK + CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99, + CUBLAS_GEMM_DFALT_TENSOR_OP = 99, + CUBLAS_GEMM_ALGO0_TENSOR_OP = 100, + CUBLAS_GEMM_ALGO1_TENSOR_OP = 101, + CUBLAS_GEMM_ALGO2_TENSOR_OP = 102, + CUBLAS_GEMM_ALGO3_TENSOR_OP = 103, + CUBLAS_GEMM_ALGO4_TENSOR_OP = 104, + CUBLAS_GEMM_ALGO5_TENSOR_OP = 105, + CUBLAS_GEMM_ALGO6_TENSOR_OP = 106, + CUBLAS_GEMM_ALGO7_TENSOR_OP = 107, + CUBLAS_GEMM_ALGO8_TENSOR_OP = 108, + CUBLAS_GEMM_ALGO9_TENSOR_OP = 109, + CUBLAS_GEMM_ALGO10_TENSOR_OP = 110, + CUBLAS_GEMM_ALGO11_TENSOR_OP = 111, + CUBLAS_GEMM_ALGO12_TENSOR_OP = 112, + CUBLAS_GEMM_ALGO13_TENSOR_OP = 113, + CUBLAS_GEMM_ALGO14_TENSOR_OP = 114, + CUBLAS_GEMM_ALGO15_TENSOR_OP = 115 +} cublasGemmAlgo_t; + +typedef enum cudaDataType_t +{ + CUDA_R_16F= 2, /* real as a half */ + CUDA_C_16F= 6, /* complex as a pair of half numbers */ + CUDA_R_32F= 0, /* real as a float */ + CUDA_C_32F= 4, /* complex as a pair of float numbers */ + CUDA_R_64F= 1, /* real as a double */ + CUDA_C_64F= 5, /* complex as a pair of double numbers */ + CUDA_R_8I = 3, /* real as a signed char */ + CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ + CUDA_R_8U = 8, /* real as a unsigned char */ + CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ + CUDA_R_32I= 10, /* real as a signed int */ + CUDA_C_32I= 11, /* complex as a pair of signed int numbers */ + CUDA_R_32U= 12, /* real as a unsigned int */ + CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */ +} cudaDataType; + +typedef cudaDataType cublasDataType_t; + +typedef enum { + CUBLAS_OP_N=0, + CUBLAS_OP_T=1, + CUBLAS_OP_C=2, + CUBLAS_OP_HERMITAN=2, /* synonym if CUBLAS_OP_C */ + CUBLAS_OP_CONJG=3 /* conjugate */ +} cublasOperation_t; + +/*Enum for default math mode/tensor operation*/ +typedef enum { + CUBLAS_DEFAULT_MATH = 0, + CUBLAS_TENSOR_OP_MATH = 1 +} cublasMath_t; + +#endif \ No newline at end of file diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h new file mode 100644 index 000000000..00814c0f0 --- /dev/null +++ b/tests/common/src/dot.h @@ -0,0 +1,77 @@ +namespace src { + + const char *dot = +R"( +#ifdef AT +#define USEA ^a +#else +#define USEA a +#endif + +#ifdef BT +#define USEB ^b +#else +#define USEB b +#endif + +void dot(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C __noalias __readonly __aligned(16), + int M, int N, int K, + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc) { + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int rxa[TM] = ridx * TM + 0 ... TM; + int ryb[TN] = ridy * TN + 0 ... TN; + int rka[TK] = 0 ... TK; + int rkb[TK] = 0 ... TK; + float xc[TM, TN] = 0; +#ifdef AT + TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda; + bool checka[TK, TM] = rka[:, newaxis] < K; + TYPE a[TK, TM] = checka ? *pa : 0; +#else + TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis]; + bool checka[TM, TK] = rka[newaxis, :] < K; + TYPE a[TM, TK] = checka ? *pa : 0; +#endif +#ifdef BT + TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis]; + bool checkb[TN, TK] = rkb[newaxis, :] < K; + TYPE b[TN, TK] = checkb ? *pb : 0; +#else + TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb; + bool checkb[TK, TN] = rkb[:, newaxis] < K; + TYPE b[TK, TN] = checkb ? *pb : 0; +#endif + for(int k = K; k > 0; k = k - TK){ + xc = USEA @ USEB + xc; +#ifdef AT + pa = pa + TK; +#else + pa = pa + TK*lda; +#endif +#ifdef BT + pb = pb + TK*ldb; +#else + pb = pb + TK; +#endif + checka = k > TK; + checkb = k > TK; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; + } + int rxc[TM] = ridx * TM + (0 ... TM); + int ryc[TN] = ridy * TN + (0 ... TN); + TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + TYPE c[TM, TN] = xc; + bool checkc0[TM] = rxc < M; + bool checkc1[TN] = ryc < N; + bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + *?(checkc) pc = c; +} +)"; + +} \ No newline at end of file diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt new file mode 100644 index 000000000..f3cdae9a1 --- /dev/null +++ b/tests/unit/CMakeLists.txt @@ -0,0 +1,6 @@ +foreach(PROG dot) + set(TARGET test_${PROG}) + add_executable(${TARGET} ${PROG}.cc) + set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) + target_link_libraries(${TARGET} triton dl) +endforeach(PROG) diff --git a/examples/cpp/dot.cc b/tests/unit/dot.cc similarity index 67% rename from examples/cpp/dot.cc rename to tests/unit/dot.cc index 7d5a44324..3ddc8953e 100644 --- a/examples/cpp/dot.cc +++ b/tests/unit/dot.cc @@ -6,7 +6,8 @@ #include "triton/tools/bench.hpp" #include "triton/external/half.hpp" #include "triton/runtime/function.h" -#include "cuda.h" +#include "src/dot.h" +#include "cuda/cublas.h" template void diff(const std::vector& x, const std::vector& y){ @@ -44,81 +45,6 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K, } - -std::string src = -R"( -#ifdef AT -#define USEA ^a -#else -#define USEA a -#endif - -#ifdef BT -#define USEB ^b -#else -#define USEB b -#endif - -void dot(TYPE * A __noalias __readonly __aligned(16), - TYPE * B __noalias __readonly __aligned(16), - TYPE * C __noalias __readonly __aligned(16), - int M, int N, int K, - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc) { - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rxa[TM] = ridx * TM + 0 ... TM; - int ryb[TN] = ridy * TN + 0 ... TN; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - float xc[TM, TN] = 0; -#ifdef AT - TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda; - bool checka[TK, TM] = rka[:, newaxis] < K; - TYPE a[TK, TM] = checka ? *pa : 0; -#else - TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis]; - bool checka[TM, TK] = rka[newaxis, :] < K; - TYPE a[TM, TK] = checka ? *pa : 0; -#endif -#ifdef BT - TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis]; - bool checkb[TN, TK] = rkb[newaxis, :] < K; - TYPE b[TN, TK] = checkb ? *pb : 0; -#else - TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb; - bool checkb[TK, TN] = rkb[:, newazis] < K; - TYPE b[TK, TN] = checkb ? *pb : 0; -#endif - for(int k = K; k > 0; k = k - TK){ - xc = USEA @ USEB + xc; -#ifdef AT - pa = pa + TK; -#else - pa = pa + TK*lda; -#endif -#ifdef BT - pb = pb + TK*ldb; -#else - pb = pb + TK; -#endif - checka = k > TK; - checkb = k > TK; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; - } - int rxc[TM] = ridx * TM + (0 ... TM); - int ryc[TN] = ridy * TN + (0 ... TN); - TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - TYPE c[TM, TN] = xc; - bool checkc0[TM] = rxc < M; - bool checkc1[TN] = ryc < N; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - *?(checkc) pc = c; -} -)"; - struct perf_t { double triton; double cublas; @@ -128,7 +54,7 @@ namespace drv = triton::driver; namespace rt = triton::runtime; perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ - typedef half NumericT; + typedef half_float::half NumericT; std::string ty = "half"; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); @@ -140,9 +66,9 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int int32_t ldc = M; srand(0); for(size_t i = 0; i < ha.size(); i++) - ha[i] = static_cast((double)rand()/RAND_MAX); + ha[i] = static_cast((float)rand()/RAND_MAX); for(size_t i = 0; i < hb.size(); i++) - hb[i] = static_cast((double)rand()/RAND_MAX); + hb[i] = static_cast((float)rand()/RAND_MAX); for(size_t i = 0; i < hc.size(); i++) hc[i] = static_cast((double)0); drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes); @@ -159,11 +85,11 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int opt.defines.push_back({"AT", {""}}); if(BT) opt.defines.push_back({"BT", {""}}); - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); + opt.defines.push_back({"TM", {"16", "32", "64", "128"}}); + opt.defines.push_back({"TN", {"16", "32", "64", "128"}}); opt.defines.push_back({"TK", {"32"}}); - opt.num_warps = {4}; - rt::function function(src, opt); + opt.num_warps = {1, 2, 4, 8}; + rt::function function(src::dot, opt); auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN")), 1}; }; @@ -171,10 +97,15 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; perf_t res; res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream)); - res.cublas = 0; + NumericT alpha(static_cast(1)); + NumericT beta(static_cast(0)); + cublasGemmAlgo_t fastest; + cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, da, lda, db, ldb, &beta, dc, ldc, &fastest); + res.cublas = tflops(triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, + &alpha, da, lda, db, ldb, &beta, dc, ldc, nullptr, fastest); }, + stream)); // test - stream->synchronize(); // stream->read(dc, true, 0, hc); // std::vector rc(hc.size()); // cpu_ref(AT, BT, M, N, K, rc, ha, hb); @@ -214,7 +145,7 @@ int main() { // shapes to benchmark std::vector configs = { // {false, false, 8192, 512, 512}, - {false, true, 8192, 8192, 8192} + {false, true, 128, 128, 128} // {false, true, 128, 128, 128}, // {false, false, 128, 128, 128}, // {true, false, 128, 128, 128},