[structure] better directory structure for tests
This commit is contained in:
@@ -4,7 +4,7 @@ include(CTest)
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
# Options
|
||||
option(BUILD_EXAMPLES "Build C++ Triton examples" ON)
|
||||
option(BUILD_TESTS "Build C++ Triton tests" ON)
|
||||
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
|
||||
# LLVM
|
||||
@@ -23,10 +23,10 @@ endif()
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
|
||||
|
||||
# Examples
|
||||
if(BUILD_EXAMPLES)
|
||||
message(STATUS "Adding C++ examples")
|
||||
add_subdirectory(examples)
|
||||
# Tests
|
||||
if(BUILD_TESTS)
|
||||
message(STATUS "Adding C++ tests")
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
# Python module
|
||||
|
@@ -1,20 +0,0 @@
|
||||
include(FindPackageHandleStandardArgs)
|
||||
unset(TENSORFLOW_FOUND)
|
||||
|
||||
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
||||
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
||||
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
||||
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||
|
||||
find_package_handle_standard_args(TensorFlow DEFAULT_MSG TF_INC TF_LIB)
|
||||
|
||||
# set external variables for usage in CMakeLists.txt
|
||||
if(TensorFlow_FOUND)
|
||||
set(TensorFlow_LIBRARIES ${TF_LIB})
|
||||
set(TensorFlow_INCLUDE_DIRS ${TF_INC})
|
||||
set(TensorFlow_ABI ${TF_ABI})
|
||||
endif()
|
||||
|
||||
mark_as_advanced(TF_INC TF_LIB TF_ABI)
|
@@ -1,14 +0,0 @@
|
||||
include(FindPackageHandleStandardArgs)
|
||||
execute_process(COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__))"
|
||||
OUTPUT_VARIABLE TORCH_INSTALL_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
|
||||
|
||||
find_package_handle_standard_args(TORCH DEFAULT_MSG TORCH_INSTALL_PREFIX)
|
||||
if(TORCH_INSTALL_PREFIX)
|
||||
set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/
|
||||
${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include
|
||||
${TORCH_INSTALL_PREFIX}/include/
|
||||
${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include/)
|
||||
set(TORCH_LIBRARY_DIRS ${TORCH_INSTALL_PREFIX}/lib/)
|
||||
endif()
|
||||
|
||||
mark_as_advanced(TORCH_INCLUDE_DIRS TORCH_LIBRARY_DIRS)
|
@@ -1 +0,0 @@
|
||||
add_subdirectory(cpp)
|
@@ -1,6 +0,0 @@
|
||||
foreach(PROG dot)
|
||||
add_executable(${PROG} ${PROG}.cc)
|
||||
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
|
||||
include_directories(/usr/local/cuda/include/)
|
||||
target_link_libraries(${PROG} triton cublas)
|
||||
endforeach(PROG)
|
@@ -1,160 +0,0 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include "cublas_v2.h"
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
enum cublasStrategy_t{
|
||||
CUBLAS_PREFER_FASTEST,
|
||||
CUBLAS_HEURISTICS
|
||||
};
|
||||
|
||||
enum DType{
|
||||
HALF_TYPE,
|
||||
FLOAT_TYPE,
|
||||
DOUBLE_TYPE,
|
||||
};
|
||||
|
||||
inline size_t size_of(DType dtype){
|
||||
switch (dtype) {
|
||||
case HALF_TYPE: return 2;
|
||||
case FLOAT_TYPE: return 4;
|
||||
case DOUBLE_TYPE: return 8;
|
||||
default: throw;
|
||||
}
|
||||
}
|
||||
|
||||
inline std::vector<cublasGemmAlgo_t> gather_all_algos() {
|
||||
std::vector<cublasGemmAlgo_t> result;
|
||||
// non-tensor ops
|
||||
for(int i = -1; i < 24; i++)
|
||||
result.push_back((cublasGemmAlgo_t)i);
|
||||
// tensor ops
|
||||
for(int i = 99; i < 116; i++)
|
||||
result.push_back((cublasGemmAlgo_t)i);
|
||||
return result;
|
||||
}
|
||||
|
||||
static const std::vector<cublasGemmAlgo_t> algorithms = gather_all_algos();
|
||||
|
||||
static const std::map<DType, cudaDataType> cu_dtype = {
|
||||
{HALF_TYPE, CUDA_R_16F},
|
||||
{FLOAT_TYPE, CUDA_R_32F},
|
||||
{DOUBLE_TYPE, CUDA_R_64F}
|
||||
};
|
||||
|
||||
static const std::map<char, cublasOperation_t> cu_op = {
|
||||
{false, CUBLAS_OP_N},
|
||||
{true, CUBLAS_OP_T}
|
||||
};
|
||||
|
||||
inline cublasGemmAlgo_t cublasGemmFastest(
|
||||
triton::driver::stream* stream,
|
||||
cublasHandle_t handle, cudaDataType cudt,
|
||||
cublasOperation_t AT, cublasOperation_t BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc) {
|
||||
|
||||
// cache to avoid re-benchmarking
|
||||
typedef std::tuple<cudaDataType_t,
|
||||
cublasOperation_t, cublasOperation_t,
|
||||
int32_t, int32_t, int32_t> key_t;
|
||||
static std::map<key_t, cublasGemmAlgo_t> cache;
|
||||
key_t key(cudt, AT, BT, M, N, K);
|
||||
// benchmark algorithms if necessary
|
||||
if(cache.find(key) == cache.end()){
|
||||
std::vector<double> times;
|
||||
for(cublasGemmAlgo_t a: algorithms) {
|
||||
cublasStatus_t status;
|
||||
double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT,
|
||||
M, N, K,
|
||||
alpha, (const void*)A, cudt, lda,
|
||||
(const void*)B, cudt, ldb,
|
||||
beta, (void*)C, cudt, ldc, cudt,
|
||||
a); }, stream);
|
||||
if(status != CUBLAS_STATUS_SUCCESS)
|
||||
nanosec = INFINITY;
|
||||
}
|
||||
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
|
||||
assert(times[argmin] != INFINITY);
|
||||
cache.insert({key, algorithms[argmin]});
|
||||
}
|
||||
|
||||
// return best algorithm
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
/* Wrapper for cublasGemmEx */
|
||||
inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
|
||||
{
|
||||
cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo);
|
||||
if(status != CUBLAS_STATUS_SUCCESS){
|
||||
std::cout << status;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Get cuBLAS handle */
|
||||
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
|
||||
static std::map<CUstream, cublasHandle_t> cache;
|
||||
CUstream key = *stream->cu();
|
||||
|
||||
// create handle if necessary
|
||||
if(cache.find(key) == cache.end()) {
|
||||
cublasHandle_t handle;
|
||||
if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||
throw std::runtime_error("Error: could not create cuBLAS handle");
|
||||
cublasSetStream_v2(handle, key);
|
||||
cache.insert({key, handle});
|
||||
}
|
||||
|
||||
// return handle for the stream
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
/* Simplified API for default GEMM */
|
||||
inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, triton::driver::buffer* A, int32_t lda,
|
||||
triton::driver::buffer* B, int32_t ldb,
|
||||
void* beta, triton::driver::buffer* C, int32_t ldc,
|
||||
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
|
||||
triton::driver::cu_context::context_switcher scope(*stream->context());
|
||||
static cublasHandle_t handle = cublasGetHandle(stream);
|
||||
if(dtype == HALF_TYPE)
|
||||
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
|
||||
cublasStatus_t status;
|
||||
if(fastest)
|
||||
*fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
|
||||
else
|
||||
status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo);
|
||||
}
|
@@ -34,7 +34,7 @@ void check(cl_int err);
|
||||
|
||||
class dispatch
|
||||
{
|
||||
private:
|
||||
protected:
|
||||
template <class F>
|
||||
struct return_type;
|
||||
|
||||
|
3
tests/CMakeLists.txt
Normal file
3
tests/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/common")
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(unit)
|
6
tests/bench/CMakeLists.txt
Normal file
6
tests/bench/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
foreach(PROG dot)
|
||||
set(TARGET bench_${PROG})
|
||||
add_executable(${TARGET} ${PROG}.cc)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||
target_link_libraries(${TARGET} triton dl)
|
||||
endforeach(PROG)
|
98
tests/bench/dot.cc
Normal file
98
tests/bench/dot.cc
Normal file
@@ -0,0 +1,98 @@
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/external/half.hpp"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "src/dot.h"
|
||||
#include "cuda/cublas.h"
|
||||
|
||||
|
||||
struct perf_t {
|
||||
double triton;
|
||||
double cublas;
|
||||
};
|
||||
|
||||
namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
inline size_t ceil(size_t x, size_t y) {
|
||||
return (x + y - 1) / y;
|
||||
};
|
||||
|
||||
|
||||
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef half_float::half NumericT;
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
drv::context* context = stream->context();
|
||||
// leading dimensions
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
// create inputs
|
||||
auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
||||
auto da = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
|
||||
auto db = std::unique_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
|
||||
// create options
|
||||
rt::function::options_space_t opt;
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
if(AT)
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"TM", {"16", "32", "64", "128"}});
|
||||
opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
|
||||
opt.defines.push_back({"TK", {"32"}});
|
||||
opt.num_warps = {1, 2, 4, 8};
|
||||
// create grid
|
||||
auto grid = [&](const rt::function::options_t& x) {
|
||||
return rt::grid_t{ceil(M, x.D<int>("TM")),
|
||||
ceil(N, x.D<int>("TN"))};
|
||||
};
|
||||
// create function
|
||||
rt::function function(src::dot, opt);
|
||||
// benchmark available libraries
|
||||
std::vector<double> result;
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
// cublas
|
||||
if(cublas::cublasinit()){
|
||||
NumericT alpha(static_cast<double>(1));
|
||||
NumericT beta(static_cast<double>(0));
|
||||
cublasGemmAlgo_t fastest;
|
||||
cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
|
||||
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
|
||||
&alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
|
||||
result.push_back(tflops(cublas_ms));
|
||||
}
|
||||
// triton
|
||||
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream);
|
||||
result.push_back(tflops(triton_ms));
|
||||
// done
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs = {
|
||||
config_t{false, true, 512, 512, 512},
|
||||
config_t{false, true, 2048, 2048, 2048},
|
||||
config_t{false, true, 8192, 8192, 8192}
|
||||
};
|
||||
// does the work
|
||||
bool AT, BT;
|
||||
int32_t M, N, K;
|
||||
for(const auto& c: configs){
|
||||
std::tie(AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
|
||||
for(auto perf: do_bench(stream, AT, BT, M, N, K))
|
||||
std::cout << ", " << perf << std::flush;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
221
tests/common/cuda/cublas.h
Normal file
221
tests/common/cuda/cublas.h
Normal file
@@ -0,0 +1,221 @@
|
||||
/* Copyright 2019 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include "forward.h"
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
|
||||
class cublas {
|
||||
private:
|
||||
template <class F>
|
||||
struct return_type;
|
||||
|
||||
template <class R, class... A>
|
||||
struct return_type<R (*)(A...)>
|
||||
{ typedef R type; };
|
||||
|
||||
typedef bool (*f_init_t)();
|
||||
|
||||
template<f_init_t initializer, typename FunPtrT, typename... Args>
|
||||
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
|
||||
{
|
||||
initializer();
|
||||
if(cache == nullptr){
|
||||
cache = dlsym(lib_h, name);
|
||||
if(cache == 0)
|
||||
throw std::runtime_error("dlsym unable to load function");
|
||||
}
|
||||
FunPtrT fptr;
|
||||
*reinterpret_cast<void **>(&fptr) = cache;
|
||||
typename return_type<FunPtrT>::type res = (*fptr)(args...);
|
||||
triton::driver::check(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
static bool cublasinit();
|
||||
static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m);
|
||||
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
|
||||
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
|
||||
static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId);
|
||||
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
|
||||
int m, int n, int k,
|
||||
const void *alpha, const void *A, cudaDataType Atype, int lda,
|
||||
const void *B, cudaDataType Btype, int ldb, const void *beta,
|
||||
void *C, cudaDataType Ctype, int ldc,
|
||||
cudaDataType computeType, cublasGemmAlgo_t algo);
|
||||
|
||||
private:
|
||||
static void* so_;
|
||||
static void* cublasGetStream_v2_;
|
||||
static void* cublasSetStream_v2_;
|
||||
static void* cublasCreate_v2_;
|
||||
static void* cublasGemmEx_;
|
||||
static void* cublasSetMathMode_;
|
||||
};
|
||||
|
||||
void* cublas::so_;
|
||||
void* cublas::cublasGetStream_v2_;
|
||||
void* cublas::cublasSetStream_v2_;
|
||||
void* cublas::cublasCreate_v2_;
|
||||
void* cublas::cublasGemmEx_;
|
||||
void* cublas::cublasSetMathMode_;
|
||||
|
||||
|
||||
bool cublas::cublasinit() {
|
||||
if(so_==nullptr)
|
||||
so_ = dlopen("libcublas.so", RTLD_LAZY);
|
||||
return so_ != nullptr;
|
||||
}
|
||||
|
||||
cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a)
|
||||
{ return f_impl<cublas::cublasinit>(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); }
|
||||
cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a)
|
||||
{ return f_impl<cublas::cublasinit>(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); }
|
||||
cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
|
||||
const void *alpha, const void *A, cudaDataType Atype, int lda,
|
||||
const void *B, cudaDataType Btype, int ldb, const void *beta,
|
||||
void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) {
|
||||
return f_impl<cublas::cublasinit>(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo);
|
||||
}
|
||||
cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) {
|
||||
return f_impl<cublas::cublasinit>(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h);
|
||||
}
|
||||
cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) {
|
||||
return f_impl<cublas::cublasinit>(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline cublasGemmAlgo_t cublasGemmFastest(
|
||||
triton::driver::stream* stream,
|
||||
cublasHandle_t handle, cudaDataType cudt,
|
||||
cublasOperation_t AT, cublasOperation_t BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc) {
|
||||
|
||||
// initialize list of cublas algorithms
|
||||
static std::vector<cublasGemmAlgo_t> algorithms;
|
||||
if(algorithms.empty()) {
|
||||
// non-tensor ops
|
||||
for(int i = -1; i < 24; i++)
|
||||
algorithms.push_back((cublasGemmAlgo_t)i);
|
||||
// tensor ops
|
||||
for(int i = 99; i < 116; i++)
|
||||
algorithms.push_back((cublasGemmAlgo_t)i);
|
||||
}
|
||||
|
||||
// cache to avoid re-benchmarking
|
||||
typedef std::tuple<cudaDataType_t,
|
||||
cublasOperation_t, cublasOperation_t,
|
||||
int32_t, int32_t, int32_t> key_t;
|
||||
static std::map<key_t, cublasGemmAlgo_t> cache;
|
||||
key_t key(cudt, AT, BT, M, N, K);
|
||||
// benchmark algorithms if necessary
|
||||
if(cache.find(key) == cache.end()){
|
||||
std::vector<double> times;
|
||||
for(cublasGemmAlgo_t a: algorithms) {
|
||||
cublasStatus_t status;
|
||||
double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT,
|
||||
M, N, K,
|
||||
alpha, (const void*)A, cudt, lda,
|
||||
(const void*)B, cudt, ldb,
|
||||
beta, (void*)C, cudt, ldc, cudt,
|
||||
a); }, stream);
|
||||
if(status != CUBLAS_STATUS_SUCCESS)
|
||||
nanosec = INFINITY;
|
||||
}
|
||||
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
|
||||
assert(times[argmin] != INFINITY);
|
||||
cache.insert({key, algorithms[argmin]});
|
||||
}
|
||||
|
||||
// return best algorithm
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/* Get cuBLAS handle */
|
||||
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
|
||||
static std::map<CUstream, cublasHandle_t> cache;
|
||||
CUstream key = *stream->cu();
|
||||
|
||||
// create handle if necessary
|
||||
if(cache.find(key) == cache.end()) {
|
||||
cublasHandle_t handle;
|
||||
if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||
throw std::runtime_error("Error: could not create cuBLAS handle");
|
||||
cublas::cublasSetStream_v2(handle, key);
|
||||
cache.insert({key, handle});
|
||||
}
|
||||
|
||||
// return handle for the stream
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Simplified API for default GEMM */
|
||||
inline void cublasGemm(cublasDataType_t dtype,
|
||||
triton::driver::stream* stream,
|
||||
bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, triton::driver::buffer* A, int32_t lda,
|
||||
triton::driver::buffer* B, int32_t ldb,
|
||||
void* beta, triton::driver::buffer* C, int32_t ldc,
|
||||
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
|
||||
|
||||
// switch triton context
|
||||
triton::driver::cu_context::context_switcher scope(*stream->context());
|
||||
// get handle
|
||||
static cublasHandle_t handle = cublasGetHandle(stream);
|
||||
// set math mode
|
||||
if(dtype == CUDA_R_16F)
|
||||
cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
|
||||
// cuda types
|
||||
static const std::map<char, cublasOperation_t> cu_op = {
|
||||
{false, CUBLAS_OP_N},
|
||||
{true, CUBLAS_OP_T}
|
||||
};
|
||||
cublasOperation_t opa = cu_op.at(AT);
|
||||
cublasOperation_t opb = cu_op.at(BT);
|
||||
// benchmark fastest
|
||||
if(fastest)
|
||||
*fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
|
||||
else {
|
||||
// execute supplied algo
|
||||
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
|
||||
alpha, (const void*)*A->cu(), dtype, lda,
|
||||
(const void*)*B->cu(), dtype, ldb,
|
||||
beta, (void*)*C->cu(), dtype, ldc, dtype, algo);
|
||||
}
|
||||
}
|
105
tests/common/cuda/forward.h
Normal file
105
tests/common/cuda/forward.h
Normal file
@@ -0,0 +1,105 @@
|
||||
#ifndef _COMMON_CUDA_FORWARDS_H_
|
||||
#define _COMMON_CUDA_FORwARDS_H_
|
||||
|
||||
struct cublasContext;
|
||||
typedef struct cublasContext *cublasHandle_t;
|
||||
struct CUstream_st;
|
||||
typedef struct CUstream_st *cudaStream_t;
|
||||
|
||||
/* CUBLAS status type returns */
|
||||
typedef enum{
|
||||
CUBLAS_STATUS_SUCCESS =0,
|
||||
CUBLAS_STATUS_NOT_INITIALIZED =1,
|
||||
CUBLAS_STATUS_ALLOC_FAILED =3,
|
||||
CUBLAS_STATUS_INVALID_VALUE =7,
|
||||
CUBLAS_STATUS_ARCH_MISMATCH =8,
|
||||
CUBLAS_STATUS_MAPPING_ERROR =11,
|
||||
CUBLAS_STATUS_EXECUTION_FAILED=13,
|
||||
CUBLAS_STATUS_INTERNAL_ERROR =14,
|
||||
CUBLAS_STATUS_NOT_SUPPORTED =15,
|
||||
CUBLAS_STATUS_LICENSE_ERROR =16
|
||||
} cublasStatus_t;
|
||||
|
||||
/*For different GEMM algorithm */
|
||||
typedef enum {
|
||||
CUBLAS_GEMM_DFALT = -1,
|
||||
CUBLAS_GEMM_DEFAULT = -1,
|
||||
CUBLAS_GEMM_ALGO0 = 0,
|
||||
CUBLAS_GEMM_ALGO1 = 1,
|
||||
CUBLAS_GEMM_ALGO2 = 2,
|
||||
CUBLAS_GEMM_ALGO3 = 3,
|
||||
CUBLAS_GEMM_ALGO4 = 4,
|
||||
CUBLAS_GEMM_ALGO5 = 5,
|
||||
CUBLAS_GEMM_ALGO6 = 6,
|
||||
CUBLAS_GEMM_ALGO7 = 7,
|
||||
CUBLAS_GEMM_ALGO8 = 8,
|
||||
CUBLAS_GEMM_ALGO9 = 9,
|
||||
CUBLAS_GEMM_ALGO10 = 10,
|
||||
CUBLAS_GEMM_ALGO11 = 11,
|
||||
CUBLAS_GEMM_ALGO12 = 12,
|
||||
CUBLAS_GEMM_ALGO13 = 13,
|
||||
CUBLAS_GEMM_ALGO14 = 14,
|
||||
CUBLAS_GEMM_ALGO15 = 15,
|
||||
CUBLAS_GEMM_ALGO16 = 16,
|
||||
CUBLAS_GEMM_ALGO17 = 17,
|
||||
CUBLAS_GEMM_ALGO18 = 18, //sliced 32x32
|
||||
CUBLAS_GEMM_ALGO19 = 19, //sliced 64x32
|
||||
CUBLAS_GEMM_ALGO20 = 20, //sliced 128x32
|
||||
CUBLAS_GEMM_ALGO21 = 21, //sliced 32x32 -splitK
|
||||
CUBLAS_GEMM_ALGO22 = 22, //sliced 64x32 -splitK
|
||||
CUBLAS_GEMM_ALGO23 = 23, //sliced 128x32 -splitK
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99,
|
||||
CUBLAS_GEMM_DFALT_TENSOR_OP = 99,
|
||||
CUBLAS_GEMM_ALGO0_TENSOR_OP = 100,
|
||||
CUBLAS_GEMM_ALGO1_TENSOR_OP = 101,
|
||||
CUBLAS_GEMM_ALGO2_TENSOR_OP = 102,
|
||||
CUBLAS_GEMM_ALGO3_TENSOR_OP = 103,
|
||||
CUBLAS_GEMM_ALGO4_TENSOR_OP = 104,
|
||||
CUBLAS_GEMM_ALGO5_TENSOR_OP = 105,
|
||||
CUBLAS_GEMM_ALGO6_TENSOR_OP = 106,
|
||||
CUBLAS_GEMM_ALGO7_TENSOR_OP = 107,
|
||||
CUBLAS_GEMM_ALGO8_TENSOR_OP = 108,
|
||||
CUBLAS_GEMM_ALGO9_TENSOR_OP = 109,
|
||||
CUBLAS_GEMM_ALGO10_TENSOR_OP = 110,
|
||||
CUBLAS_GEMM_ALGO11_TENSOR_OP = 111,
|
||||
CUBLAS_GEMM_ALGO12_TENSOR_OP = 112,
|
||||
CUBLAS_GEMM_ALGO13_TENSOR_OP = 113,
|
||||
CUBLAS_GEMM_ALGO14_TENSOR_OP = 114,
|
||||
CUBLAS_GEMM_ALGO15_TENSOR_OP = 115
|
||||
} cublasGemmAlgo_t;
|
||||
|
||||
typedef enum cudaDataType_t
|
||||
{
|
||||
CUDA_R_16F= 2, /* real as a half */
|
||||
CUDA_C_16F= 6, /* complex as a pair of half numbers */
|
||||
CUDA_R_32F= 0, /* real as a float */
|
||||
CUDA_C_32F= 4, /* complex as a pair of float numbers */
|
||||
CUDA_R_64F= 1, /* real as a double */
|
||||
CUDA_C_64F= 5, /* complex as a pair of double numbers */
|
||||
CUDA_R_8I = 3, /* real as a signed char */
|
||||
CUDA_C_8I = 7, /* complex as a pair of signed char numbers */
|
||||
CUDA_R_8U = 8, /* real as a unsigned char */
|
||||
CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */
|
||||
CUDA_R_32I= 10, /* real as a signed int */
|
||||
CUDA_C_32I= 11, /* complex as a pair of signed int numbers */
|
||||
CUDA_R_32U= 12, /* real as a unsigned int */
|
||||
CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */
|
||||
} cudaDataType;
|
||||
|
||||
typedef cudaDataType cublasDataType_t;
|
||||
|
||||
typedef enum {
|
||||
CUBLAS_OP_N=0,
|
||||
CUBLAS_OP_T=1,
|
||||
CUBLAS_OP_C=2,
|
||||
CUBLAS_OP_HERMITAN=2, /* synonym if CUBLAS_OP_C */
|
||||
CUBLAS_OP_CONJG=3 /* conjugate */
|
||||
} cublasOperation_t;
|
||||
|
||||
/*Enum for default math mode/tensor operation*/
|
||||
typedef enum {
|
||||
CUBLAS_DEFAULT_MATH = 0,
|
||||
CUBLAS_TENSOR_OP_MATH = 1
|
||||
} cublasMath_t;
|
||||
|
||||
#endif
|
77
tests/common/src/dot.h
Normal file
77
tests/common/src/dot.h
Normal file
@@ -0,0 +1,77 @@
|
||||
namespace src {
|
||||
|
||||
const char *dot =
|
||||
R"(
|
||||
#ifdef AT
|
||||
#define USEA ^a
|
||||
#else
|
||||
#define USEA a
|
||||
#endif
|
||||
|
||||
#ifdef BT
|
||||
#define USEB ^b
|
||||
#else
|
||||
#define USEB b
|
||||
#endif
|
||||
|
||||
void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __readonly __aligned(16),
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[TM, TN] = 0;
|
||||
#ifdef AT
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
bool checka[TK, TM] = rka[:, newaxis] < K;
|
||||
TYPE a[TK, TM] = checka ? *pa : 0;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
bool checka[TM, TK] = rka[newaxis, :] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
#endif
|
||||
#ifdef BT
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
bool checkb[TN, TK] = rkb[newaxis, :] < K;
|
||||
TYPE b[TN, TK] = checkb ? *pb : 0;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
bool checkb[TK, TN] = rkb[:, newaxis] < K;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
#endif
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = USEA @ USEB + xc;
|
||||
#ifdef AT
|
||||
pa = pa + TK;
|
||||
#else
|
||||
pa = pa + TK*lda;
|
||||
#endif
|
||||
#ifdef BT
|
||||
pb = pb + TK*ldb;
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#endif
|
||||
checka = k > TK;
|
||||
checkb = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
TYPE c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
}
|
6
tests/unit/CMakeLists.txt
Normal file
6
tests/unit/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
foreach(PROG dot)
|
||||
set(TARGET test_${PROG})
|
||||
add_executable(${TARGET} ${PROG}.cc)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||
target_link_libraries(${TARGET} triton dl)
|
||||
endforeach(PROG)
|
@@ -6,7 +6,8 @@
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/external/half.hpp"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "cuda.h"
|
||||
#include "src/dot.h"
|
||||
#include "cuda/cublas.h"
|
||||
|
||||
template<class T>
|
||||
void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
@@ -44,81 +45,6 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string src =
|
||||
R"(
|
||||
#ifdef AT
|
||||
#define USEA ^a
|
||||
#else
|
||||
#define USEA a
|
||||
#endif
|
||||
|
||||
#ifdef BT
|
||||
#define USEB ^b
|
||||
#else
|
||||
#define USEB b
|
||||
#endif
|
||||
|
||||
void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __readonly __aligned(16),
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[TM, TN] = 0;
|
||||
#ifdef AT
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
bool checka[TK, TM] = rka[:, newaxis] < K;
|
||||
TYPE a[TK, TM] = checka ? *pa : 0;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
bool checka[TM, TK] = rka[newaxis, :] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
#endif
|
||||
#ifdef BT
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
bool checkb[TN, TK] = rkb[newaxis, :] < K;
|
||||
TYPE b[TN, TK] = checkb ? *pb : 0;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
bool checkb[TK, TN] = rkb[:, newazis] < K;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
#endif
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = USEA @ USEB + xc;
|
||||
#ifdef AT
|
||||
pa = pa + TK;
|
||||
#else
|
||||
pa = pa + TK*lda;
|
||||
#endif
|
||||
#ifdef BT
|
||||
pb = pb + TK*ldb;
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#endif
|
||||
checka = k > TK;
|
||||
checkb = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
TYPE c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
struct perf_t {
|
||||
double triton;
|
||||
double cublas;
|
||||
@@ -128,7 +54,7 @@ namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef half NumericT;
|
||||
typedef half_float::half NumericT;
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
drv::context* context = stream->context();
|
||||
@@ -140,9 +66,9 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
int32_t ldc = M;
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
ha[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = static_cast<NumericT>((double)0);
|
||||
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
|
||||
@@ -159,11 +85,11 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TM", {"16", "32", "64", "128"}});
|
||||
opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
|
||||
opt.defines.push_back({"TK", {"32"}});
|
||||
opt.num_warps = {4};
|
||||
rt::function function(src, opt);
|
||||
opt.num_warps = {1, 2, 4, 8};
|
||||
rt::function function(src::dot, opt);
|
||||
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D<int>("TM")), ceil(N, x.D<int>("TN")), 1}; };
|
||||
@@ -171,10 +97,15 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
perf_t res;
|
||||
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
|
||||
res.cublas = 0;
|
||||
NumericT alpha(static_cast<double>(1));
|
||||
NumericT beta(static_cast<double>(0));
|
||||
cublasGemmAlgo_t fastest;
|
||||
cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, da, lda, db, ldb, &beta, dc, ldc, &fastest);
|
||||
res.cublas = tflops(triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
|
||||
&alpha, da, lda, db, ldb, &beta, dc, ldc, nullptr, fastest); },
|
||||
stream));
|
||||
|
||||
// test
|
||||
stream->synchronize();
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// std::vector<NumericT> rc(hc.size());
|
||||
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
||||
@@ -214,7 +145,7 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 8192, 8192, 8192}
|
||||
{false, true, 128, 128, 128}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
Reference in New Issue
Block a user