Files
triton/tests/common/cuda/cublas.h
2019-08-27 20:33:38 -07:00

222 lines
8.9 KiB
C++

/* Copyright 2019 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <algorithm>
#include <vector>
#include <cassert>
#include "forward.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/context.h"
#include "triton/driver/error.h"
#include "triton/tools/bench.hpp"
class cublas {
private:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
triton::driver::check(res);
return res;
}
public:
static bool cublasinit();
static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m);
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId);
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const void *alpha, const void *A, cudaDataType Atype, int lda,
const void *B, cudaDataType Btype, int ldb, const void *beta,
void *C, cudaDataType Ctype, int ldc,
cudaDataType computeType, cublasGemmAlgo_t algo);
private:
static void* so_;
static void* cublasGetStream_v2_;
static void* cublasSetStream_v2_;
static void* cublasCreate_v2_;
static void* cublasGemmEx_;
static void* cublasSetMathMode_;
};
void* cublas::so_;
void* cublas::cublasGetStream_v2_;
void* cublas::cublasSetStream_v2_;
void* cublas::cublasCreate_v2_;
void* cublas::cublasGemmEx_;
void* cublas::cublasSetMathMode_;
bool cublas::cublasinit() {
if(so_==nullptr)
so_ = dlopen("libcublas.so", RTLD_LAZY);
return so_ != nullptr;
}
cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a)
{ return f_impl<cublas::cublasinit>(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); }
cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a)
{ return f_impl<cublas::cublasinit>(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); }
cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
const void *alpha, const void *A, cudaDataType Atype, int lda,
const void *B, cudaDataType Btype, int ldb, const void *beta,
void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) {
return f_impl<cublas::cublasinit>(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo);
}
cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) {
return f_impl<cublas::cublasinit>(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h);
}
cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) {
return f_impl<cublas::cublasinit>(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m);
}
inline cublasGemmAlgo_t cublasGemmFastest(
triton::driver::stream* stream,
cublasHandle_t handle, cudaDataType cudt,
cublasOperation_t AT, cublasOperation_t BT,
int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc) {
// initialize list of cublas algorithms
static std::vector<cublasGemmAlgo_t> algorithms;
if(algorithms.empty()) {
// non-tensor ops
for(int i = -1; i < 24; i++)
algorithms.push_back((cublasGemmAlgo_t)i);
// tensor ops
for(int i = 99; i < 116; i++)
algorithms.push_back((cublasGemmAlgo_t)i);
}
// cache to avoid re-benchmarking
typedef std::tuple<cudaDataType_t,
cublasOperation_t, cublasOperation_t,
int32_t, int32_t, int32_t> key_t;
static std::map<key_t, cublasGemmAlgo_t> cache;
key_t key(cudt, AT, BT, M, N, K);
// benchmark algorithms if necessary
if(cache.find(key) == cache.end()){
std::vector<double> times;
for(cublasGemmAlgo_t a: algorithms) {
cublasStatus_t status;
double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT,
M, N, K,
alpha, (const void*)A, cudt, lda,
(const void*)B, cudt, ldb,
beta, (void*)C, cudt, ldc, cudt,
a); }, stream);
if(status != CUBLAS_STATUS_SUCCESS)
nanosec = INFINITY;
}
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
assert(times[argmin] != INFINITY);
cache.insert({key, algorithms[argmin]});
}
// return best algorithm
return cache.at(key);
}
/* Get cuBLAS handle */
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
static std::map<CUstream, cublasHandle_t> cache;
CUstream key = *stream->cu();
// create handle if necessary
if(cache.find(key) == cache.end()) {
cublasHandle_t handle;
if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
throw std::runtime_error("Error: could not create cuBLAS handle");
cublas::cublasSetStream_v2(handle, key);
cache.insert({key, handle});
}
// return handle for the stream
return cache.at(key);
}
/* Simplified API for default GEMM */
inline void cublasGemm(cublasDataType_t dtype,
triton::driver::stream* stream,
bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
void* alpha, triton::driver::buffer* A, int32_t lda,
triton::driver::buffer* B, int32_t ldb,
void* beta, triton::driver::buffer* C, int32_t ldc,
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
// switch triton context
triton::driver::cu_context::context_switcher scope(*stream->context());
// get handle
static cublasHandle_t handle = cublasGetHandle(stream);
// set math mode
if(dtype == CUDA_R_16F)
cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
// cuda types
static const std::map<char, cublasOperation_t> cu_op = {
{false, CUBLAS_OP_N},
{true, CUBLAS_OP_T}
};
cublasOperation_t opa = cu_op.at(AT);
cublasOperation_t opb = cu_op.at(BT);
// benchmark fastest
if(fastest)
*fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
else {
// execute supplied algo
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
alpha, (const void*)*A->cu(), dtype, lda,
(const void*)*B->cu(), dtype, ldb,
beta, (void*)*C->cu(), dtype, ldc, dtype, algo);
}
}