/* Copyright 2015-2017 Philippe Tillet * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files * (the "Software"), to deal in the Software without restriction, * including without limitation the rights to use, copy, modify, merge, * publish, distribute, sublicense, and/or sell copies of the Software, * and to permit persons to whom the Software is furnished to do so, * subject to the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include #include #include #include "cublas_v2.h" #include "triton/driver/buffer.h" #include "triton/driver/stream.h" #include "triton/driver/context.h" #include "triton/tools/bench.hpp" enum cublasStrategy_t{ CUBLAS_PREFER_FASTEST, CUBLAS_HEURISTICS }; enum DType{ HALF_TYPE, FLOAT_TYPE, DOUBLE_TYPE, }; inline size_t size_of(DType dtype){ switch (dtype) { case HALF_TYPE: return 2; case FLOAT_TYPE: return 4; case DOUBLE_TYPE: return 8; default: throw; } } std::vector gather_all_algos() { std::vector result; // non-tensor ops for(int i = -1; i < 24; i++) result.push_back((cublasGemmAlgo_t)i); // tensor ops for(int i = 99; i < 116; i++) result.push_back((cublasGemmAlgo_t)i); return result; } static const std::vector algorithms = gather_all_algos(); static const std::map cu_dtype = { {HALF_TYPE, CUDA_R_16F}, {FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE, CUDA_R_64F} }; static const std::map cu_op = { {false, CUBLAS_OP_N}, {true, CUBLAS_OP_T} }; inline cublasGemmAlgo_t cublasGemmFastest( triton::driver::stream* stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K, void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, void* beta, CUdeviceptr C, int32_t ldc) { // cache to avoid re-benchmarking typedef std::tuple key_t; static std::map cache; key_t key(cudt, AT, BT, M, N, K); // benchmark algorithms if necessary if(cache.find(key) == cache.end()){ std::vector times; for(cublasGemmAlgo_t a: algorithms) { cublasStatus_t status; double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, a); }, stream); if(status != CUBLAS_STATUS_SUCCESS) nanosec = INFINITY; } size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); assert(times[argmin] != INFINITY); cache.insert({key, algorithms[argmin]}); } // return best algorithm return cache.at(key); } /* Wrapper for cublasGemmEx */ inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K, void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo) { cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); if(status != CUBLAS_STATUS_SUCCESS){ std::cout << status; exit(EXIT_FAILURE); } } /* Get cuBLAS handle */ cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { static std::map cache; CUstream key = *stream->cu(); // create handle if necessary if(cache.find(key) == cache.end()) { cublasHandle_t handle; if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) throw std::runtime_error("Error: could not create cuBLAS handle"); cublasSetStream_v2(handle, key); cache.insert({key, handle}); } // return handle for the stream return cache.at(key); } /* Simplified API for default GEMM */ inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, void* alpha, triton::driver::buffer* A, int32_t lda, triton::driver::buffer* B, int32_t ldb, void* beta, triton::driver::buffer* C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { triton::driver::cu_context::context_switcher scope(*stream->context()); static cublasHandle_t handle = cublasGetHandle(stream); if(dtype == HALF_TYPE) cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); cublasStatus_t status; if(fastest) *fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); else status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo); }