[examples/pytorch] added a bunch of models for more thorough testing
This commit is contained in:
75
examples/python/pytorch/common.hpp
Normal file
75
examples/python/pytorch/common.hpp
Normal file
@@ -0,0 +1,75 @@
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include "triton/driver/device.h"
|
||||
#include <algorithm>
|
||||
|
||||
class timer{
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
|
||||
public:
|
||||
explicit timer(bool run = false)
|
||||
{ if (run) start(); }
|
||||
|
||||
void start()
|
||||
{ _start = high_resolution_clock::now(); }
|
||||
|
||||
nanoseconds get() const
|
||||
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
||||
|
||||
private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
T min(std::vector<T> x)
|
||||
{ return *std::min_element(x.begin(), x.end()); }
|
||||
|
||||
|
||||
template<class OP, class SYNC>
|
||||
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
||||
{
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
op();
|
||||
sync();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
tmr.start();
|
||||
op();
|
||||
sync();
|
||||
times.push_back(norm*tmr.get().count());
|
||||
total_time+=times.back();
|
||||
}
|
||||
return min(times);
|
||||
}
|
||||
|
||||
// helper function to print a tuple of any size
|
||||
template<class Tuple, std::size_t N>
|
||||
struct TuplePrinter {
|
||||
static void print(const Tuple& t)
|
||||
{
|
||||
TuplePrinter<Tuple, N-1>::print(t);
|
||||
std::cout << ", " << std::get<N-1>(t);
|
||||
}
|
||||
};
|
||||
|
||||
template<class Tuple>
|
||||
struct TuplePrinter<Tuple, 1> {
|
||||
static void print(const Tuple& t)
|
||||
{
|
||||
std::cout << std::get<0>(t);
|
||||
}
|
||||
};
|
||||
|
||||
template<class... Args>
|
||||
void print(const std::tuple<Args...>& t)
|
||||
{
|
||||
std::cout << "(";
|
||||
TuplePrinter<decltype(t), sizeof...(Args)>::print(t);
|
||||
std::cout << ")\n";
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user