/* Copyright 2019 Philippe Tillet * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files * (the "Software"), to deal in the Software without restriction, * including without limitation the rights to use, copy, modify, merge, * publish, distribute, sublicense, and/or sell copies of the Software, * and to permit persons to whom the Software is furnished to do so, * subject to the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include #include #include #include "forward.h" #include "triton/driver/buffer.h" #include "triton/driver/stream.h" #include "triton/driver/context.h" #include "triton/driver/error.h" #include "triton/tools/bench.hpp" class cublas { private: template struct return_type; template struct return_type { typedef R type; }; typedef bool (*f_init_t)(); template static typename return_type::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args) { initializer(); if(cache == nullptr){ cache = dlsym(lib_h, name); if(cache == 0) throw std::runtime_error("dlsym unable to load function"); } FunPtrT fptr; *reinterpret_cast(&fptr) = cache; typename return_type::type res = (*fptr)(args...); triton::driver::check(res); return res; } public: static bool cublasinit(); static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m); static cublasStatus_t cublasCreate_v2(cublasHandle_t* h); static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId); static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId); static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo); private: static void* so_; static void* cublasGetStream_v2_; static void* cublasSetStream_v2_; static void* cublasCreate_v2_; static void* cublasGemmEx_; static void* cublasSetMathMode_; }; void* cublas::so_; void* cublas::cublasGetStream_v2_; void* cublas::cublasSetStream_v2_; void* cublas::cublasCreate_v2_; void* cublas::cublasGemmEx_; void* cublas::cublasSetMathMode_; bool cublas::cublasinit() { if(so_==nullptr) so_ = dlopen("libcublas.so", RTLD_LAZY); return so_ != nullptr; } cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a) { return f_impl(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); } cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a) { return f_impl(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); } cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) { return f_impl(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); } cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) { return f_impl(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h); } cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) { return f_impl(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m); } inline cublasGemmAlgo_t cublasGemmFastest( triton::driver::stream* stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K, void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, void* beta, CUdeviceptr C, int32_t ldc) { // initialize list of cublas algorithms static std::vector algorithms; if(algorithms.empty()) { // non-tensor ops for(int i = -1; i < 24; i++) algorithms.push_back((cublasGemmAlgo_t)i); // tensor ops for(int i = 99; i < 116; i++) algorithms.push_back((cublasGemmAlgo_t)i); } // cache to avoid re-benchmarking typedef std::tuple key_t; static std::map cache; key_t key(cudt, AT, BT, M, N, K); // benchmark algorithms if necessary if(cache.find(key) == cache.end()){ std::vector times; for(cublasGemmAlgo_t a: algorithms) { cublasStatus_t status; double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, a); }, stream); if(status != CUBLAS_STATUS_SUCCESS) nanosec = INFINITY; } size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); assert(times[argmin] != INFINITY); cache.insert({key, algorithms[argmin]}); } // return best algorithm return cache.at(key); } /* Get cuBLAS handle */ inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { static std::map cache; CUstream key = *stream->cu(); // create handle if necessary if(cache.find(key) == cache.end()) { cublasHandle_t handle; if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) throw std::runtime_error("Error: could not create cuBLAS handle"); cublas::cublasSetStream_v2(handle, key); cache.insert({key, handle}); } // return handle for the stream return cache.at(key); } /* Simplified API for default GEMM */ inline void cublasGemm(cublasDataType_t dtype, triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, void* alpha, triton::driver::buffer* A, int32_t lda, triton::driver::buffer* B, int32_t ldb, void* beta, triton::driver::buffer* C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { // switch triton context triton::driver::cu_context::context_switcher scope(*stream->context()); // get handle static cublasHandle_t handle = cublasGetHandle(stream); // set math mode if(dtype == CUDA_R_16F) cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); // cuda types static const std::map cu_op = { {false, CUBLAS_OP_N}, {true, CUBLAS_OP_T} }; cublasOperation_t opa = cu_op.at(AT); cublasOperation_t opb = cu_op.at(BT); // benchmark fastest if(fastest) *fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); else { // execute supplied algo cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K, alpha, (const void*)*A->cu(), dtype, lda, (const void*)*B->cu(), dtype, ldb, beta, (void*)*C->cu(), dtype, ldc, dtype, algo); } }