[general] a bunch of fixes in anticipation of proper triton vs cudnn

benchmarks

* DNN: Added partial auto-tuning mode and skeleton for heuristics
* Examples: Moduralized benchmarking and now evaluating ResNet-18 shapes
This commit is contained in:
Philippe Tillet
2019-07-21 20:17:56 -07:00
parent b1d81a5802
commit ead368d1ed
10 changed files with 221 additions and 147 deletions

View File

@@ -1,6 +1,7 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/heuristics.h"
#include <string>
namespace triton{
@@ -147,99 +148,12 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
// small search space for partial auto-tuning
std::vector<params_t> dot::search_space() const {
typedef std::vector<unsigned> params_t;
typedef std::tuple<size_t, size_t> key_t;
static std::vector<key_t> keys = {
{16, 16}, {16, 32}, {16, 64}, {16, 128},
{32, 16}, {32, 32}, {32, 64}, {32, 128},
{64, 16}, {64, 32}, {64, 64}, {64, 128},
{128, 16},{128, 32},{128, 64},{128, 128}
};
static std::vector<params_t> space_nn = {
{4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
{4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
{4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
{4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
{8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
{8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
{8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
{8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
{8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
};
static std::vector<params_t> space_nt = {
{4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
{4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
{4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
{8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
{4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
{8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
{8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
{8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
{8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
{16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
{8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
{8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
};
static std::vector<params_t> space_tn = {
{8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
{4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
{16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
{4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
{8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
{16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
};
static std::vector<params_t> space_tt = {
{4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
{16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
{4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
{32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
{8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
{32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
{32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
{32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
{32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
};
if(!AT_ && !BT_)
return space_nn;
else if(!AT_ && BT_)
return space_nt;
else if(AT_ && !BT_)
return space_tn;
else
return space_tt;
return dot_search_space(AT_, BT_);
}
// simple parameter heuristics
params_t dot::heuristics() const {
return search_space().back();
return dot_heuristics(AT_, BT_, M_, N_, K_);
}
}