Files
triton/examples/python/pytorch/common.hpp

76 lines
1.6 KiB
C++

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <vector>
#include "triton/driver/device.h"
/// Minimal stopwatch built on std::chrono::high_resolution_clock.
///
/// The reference point is set by start(); get() reports the time elapsed
/// since then. If start() was never called, the reference point is the
/// default-constructed time_point (the clock's epoch).
class timer {
  using clock_type = std::chrono::high_resolution_clock;
  using duration_type = std::chrono::nanoseconds;

public:
  /// Creates a stopwatch; when `run` is true it is started immediately.
  explicit timer(bool run = false)
  {
    if (run)
      start();
  }

  /// Sets the reference point to the current time.
  void start()
  {
    _start = clock_type::now();
  }

  /// Returns the time elapsed since the reference point, in nanoseconds.
  duration_type get() const
  {
    const auto elapsed = clock_type::now() - _start;
    return std::chrono::duration_cast<duration_type>(elapsed);
  }

private:
  clock_type::time_point _start;
};
/// Returns a copy of the smallest element of `x` (compared with operator<).
///
/// The vector is taken by const reference so that looking up the minimum does
/// not copy the whole container.
///
/// Precondition: `x` must not be empty; dereferencing the end iterator of an
/// empty range is undefined behavior.
template<class T>
T min(const std::vector<T>& x)
{
  return *std::min_element(x.begin(), x.end());
}
/// Benchmarks `op` and returns the fastest observed run, in nanoseconds.
///
/// `op()` followed by `sync()` is first executed once without being timed.
/// The pair is then timed repeatedly (each sample covers one `op()` plus one
/// `sync()`) until the accumulated measured time reaches 1 ms; at least one
/// timed sample is always taken. The minimum sample is returned.
///
/// @param op    callable to benchmark
/// @param sync  callable invoked after every `op()`, inside the timed region
/// @param device  currently unused by this function
/// @return the smallest measured duration of `op(); sync();` in nanoseconds
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & /*device*/)
{
  // Untimed warm-up run.
  op();
  sync();

  timer tmr;
  double total_time = 0;  // accumulated measured time, in nanoseconds
  double best = 0;        // smallest sample seen so far, in nanoseconds
  bool have_sample = false;
  do {
    tmr.start();
    op();
    sync();
    // Keep the sample in double: routing it through float would round
    // any duration above 2^24 ns.
    const double elapsed = static_cast<double>(tmr.get().count());
    best = have_sample ? std::min(best, elapsed) : elapsed;
    have_sample = true;
    total_time += elapsed;
  } while (total_time*1e-9 < 1e-3);
  return best;
}
// Helpers to print a tuple of any size (including an empty one) to std::cout.

/// Prints the first N elements of a tuple, separated by ", ".
/// Recurses on the first N-1 elements, then appends element N-1.
template<class Tuple, std::size_t N>
struct TuplePrinter {
  static void print(const Tuple& t)
  {
    TuplePrinter<Tuple, N-1>::print(t);
    std::cout << ", " << std::get<N-1>(t);
  }
};

/// Recursion base case: a single element is printed without a separator.
template<class Tuple>
struct TuplePrinter<Tuple, 1> {
  static void print(const Tuple& t)
  {
    std::cout << std::get<0>(t);
  }
};

/// Empty tuple: nothing to print. Without this specialization N == 0 would
/// recurse to N-1 (wrapping around as an unsigned value) and fail to compile.
template<class Tuple>
struct TuplePrinter<Tuple, 0> {
  static void print(const Tuple&)
  {
  }
};

/// Writes `t` to std::cout as "(e0, e1, ...)" followed by a newline.
/// An empty tuple is written as "()".
template<class... Args>
void print(const std::tuple<Args...>& t)
{
  std::cout << "(";
  TuplePrinter<std::tuple<Args...>, sizeof...(Args)>::print(t);
  std::cout << ")\n";
}