[general] a bunch of fixes in anticipation of proper triton vs cudnn
benchmarks * DNN: Added partial auto-tuning mode and skeleton for heuristics * Examples: Moduralized benchmarking and now evaluating ResNet-18 shapes
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
@@ -16,7 +17,7 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
||||
double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float T;
|
||||
std::string ty = "fp16";
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
@@ -39,11 +40,11 @@ double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int3
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
|
||||
double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
delete dc;
|
||||
delete da;
|
||||
delete db;
|
||||
return result;
|
||||
return dot.num_flops() / nanosec * 1e-3;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -53,20 +54,28 @@ int main() {
|
||||
int32_t M;
|
||||
int32_t N;
|
||||
int32_t K;
|
||||
|
||||
std::string repr() {
|
||||
std::ostringstream oss;
|
||||
oss << AT << " " << BT << " " << M << " " << N << " " << K;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, AT, BT, M, N, K);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
{false, false, 4096, 4096, 4096},
|
||||
{false, true, 4096, 4096, 4096},
|
||||
{true, false, 4096, 4096, 4096},
|
||||
{true, true, 4096, 4096, 4096}
|
||||
{false, true, 4096, 4096, 4096},
|
||||
{true, false, 4096, 4096, 4096},
|
||||
{true, true, 4096, 4096, 4096}
|
||||
};
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// does the work
|
||||
for(config_t c: configs){
|
||||
double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
|
||||
double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
|
||||
std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user