76 lines
1.6 KiB
C++
76 lines
1.6 KiB
C++
#include <vector>
|
|
#include <chrono>
|
|
#include <cmath>
|
|
#include "triton/driver/device.h"
|
|
#include <algorithm>
|
|
|
|
class timer{
|
|
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
|
typedef std::chrono::nanoseconds nanoseconds;
|
|
|
|
public:
|
|
explicit timer(bool run = false)
|
|
{ if (run) start(); }
|
|
|
|
void start()
|
|
{ _start = high_resolution_clock::now(); }
|
|
|
|
nanoseconds get() const
|
|
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
|
|
|
private:
|
|
high_resolution_clock::time_point _start;
|
|
};
|
|
|
|
template<class T>
|
|
T min(std::vector<T> x)
|
|
{ return *std::min_element(x.begin(), x.end()); }
|
|
|
|
|
|
template<class OP, class SYNC>
|
|
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
|
{
|
|
timer tmr;
|
|
std::vector<size_t> times;
|
|
double total_time = 0;
|
|
op();
|
|
sync();
|
|
while(total_time*1e-9 < 1e-3){
|
|
float norm = 1;
|
|
tmr.start();
|
|
op();
|
|
sync();
|
|
times.push_back(norm*tmr.get().count());
|
|
total_time+=times.back();
|
|
}
|
|
return min(times);
|
|
}
|
|
|
|
// helper function to print a tuple of any size
|
|
template<class Tuple, std::size_t N>
|
|
struct TuplePrinter {
|
|
static void print(const Tuple& t)
|
|
{
|
|
TuplePrinter<Tuple, N-1>::print(t);
|
|
std::cout << ", " << std::get<N-1>(t);
|
|
}
|
|
};
|
|
|
|
template<class Tuple>
|
|
struct TuplePrinter<Tuple, 1> {
|
|
static void print(const Tuple& t)
|
|
{
|
|
std::cout << std::get<0>(t);
|
|
}
|
|
};
|
|
|
|
template<class... Args>
|
|
void print(const std::tuple<Args...>& t)
|
|
{
|
|
std::cout << "(";
|
|
TuplePrinter<decltype(t), sizeof...(Args)>::print(t);
|
|
std::cout << ")\n";
|
|
}
|
|
|
|
|