[general] a bunch of fixes in anticipation of proper triton vs cudnn

benchmarks

* DNN: Added partial auto-tuning mode and skeleton for heuristics
* Examples: Modularized benchmarking and now evaluating ResNet-18 shapes
This commit is contained in:
Philippe Tillet
2019-07-21 20:17:56 -07:00
parent b1d81a5802
commit ead368d1ed
10 changed files with 221 additions and 147 deletions

View File

@@ -1,4 +1,5 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
@@ -16,7 +17,7 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
std::cout << "Pass!" << std::endl;
}
double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
@@ -39,11 +40,11 @@ double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int3
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
delete dc;
delete da;
delete db;
return result;
return dot.num_flops() / nanosec * 1e-3;
}
int main() {
@@ -53,20 +54,28 @@ int main() {
int32_t M;
int32_t N;
int32_t K;
std::string repr() {
std::ostringstream oss;
oss << AT << " " << BT << " " << M << " " << N << " " << K;
return oss.str();
}
double perf(triton::driver::context *context){
return do_bench(context, AT, BT, M, N, K);
}
};
// shapes to benchmark
std::vector<config_t> configs = {
{false, false, 4096, 4096, 4096},
{false, true, 4096, 4096, 4096},
{true, false, 4096, 4096, 4096},
{true, true, 4096, 4096, 4096}
{false, true, 4096, 4096, 4096},
{true, false, 4096, 4096, 4096},
{true, true, 4096, 4096, 4096}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// does the work
for(config_t c: configs){
double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
}
}

View File

@@ -10,7 +10,7 @@
double do_bench(triton::driver::context* context,
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout,
triton::dnn::op_t op, triton::dnn::layout_t layout,
std::string numeric_t) {
typedef float NumericT;
@@ -25,14 +25,14 @@ double do_bench(triton::driver::context* context,
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t, numeric_t,
op, false, triton::dnn::shift::CHWN);
op, false, layout);
// host buffers
size_t a_size = B*C*H*W;
size_t b_size = C*F;
size_t c_size = B*F*H*W;
if(op == triton::dnn::shift::BPROP)
if(op == triton::dnn::BPROP)
std::swap(a_size, c_size);
if(op == triton::dnn::shift::WGRAD){
if(op == triton::dnn::WGRAD){
std::swap(b_size, c_size);
std::swap(a_size, b_size);
}
@@ -58,20 +58,57 @@ double do_bench(triton::driver::context* context,
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, {da, db, dc}, true);
double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, true);}, stream);
std::cout << tns << std::endl;
double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream);
return shift.num_flops() / nanosec * 1e-3;
}
int main() {
using triton::dnn::op_t;
using triton::dnn::layout_t;
struct config_t{
int32_t B;
int32_t C;
int32_t H;
int32_t W;
int32_t R;
int32_t S;
int32_t F;
int32_t stride_h;
int32_t stride_w;
op_t op;
layout_t layout;
std::string ty;
std::string repr() {
std::ostringstream oss;
oss << B << ", " << C << ", " << H << ", " << W << ", " << R << ", " << S << ", " << F << ", " << op << ", " << layout << ", " << ty;
return oss.str();
}
double perf(triton::driver::context *context){
return do_bench(context, R, S, B, F, H, W, C, op, layout, ty);
}
};
// shapes to benchmark
std::vector<config_t> configs;
std::vector<config_t> resnet18 = {
{128, 128, 32, 32, 3, 3, 128, 1, 1},
{128, 128, 32, 32, 3, 3, 256, 2, 2},
{128, 256, 16, 16, 3, 3, 256, 1, 1},
{128, 256, 16, 16, 3, 3, 512, 2, 2},
{128, 512, 8, 8, 3, 3, 512, 1, 1},
{128, 512, 8, 8, 3, 3, 1024, 1, 1},
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
};
for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
}
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// shapes
int32_t R = 3, S = 3;
int32_t B = 16, F = 4096;
int32_t H = 32, W = 32;
int32_t C = 4096;
// benchmark
do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16");
for(config_t c: configs)
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
}

View File

@@ -11,14 +11,14 @@
void extract_shapes(const torch::Tensor &x,
int64_t &C, int64_t &H, int64_t &W, int64_t &B,
triton::dnn::shift::layout_t layout) {
if(layout == triton::dnn::shift::CHWN){
triton::dnn::layout_t layout) {
if(layout == triton::dnn::CHWN){
C = x.size(0);
H = x.size(1);
W = x.size(2);
B = x.size(3);
}
else if(layout == triton::dnn::shift::NCHW){
else if(layout == triton::dnn::NCHW){
B = x.size(0);
C = x.size(1);
H = x.size(2);
@@ -29,14 +29,14 @@ void extract_shapes(const torch::Tensor &x,
}
}
static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW;
static const triton::dnn::layout_t layout = triton::dnn::NCHW;
torch::Tensor shift_common(
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
int32_t T, int32_t R, int32_t S, int32_t F,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w,
triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
triton::dnn::op_t op, triton::dnn::layout_t layout,
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
bool autotune = false
) {
@@ -59,7 +59,7 @@ torch::Tensor shift_common(
triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
stride_h, stride_w,
shift_h, shift_w, dtype, dtype,
ty, has_bias, layout);
op, has_bias, layout);
// Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
@@ -74,8 +74,9 @@ torch::Tensor shift_common(
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
std::cout << B << ", " << C << ", " << H << ", " << W << ", " << T << ", " << R << ", " << S << ", " << F << ", " << stride_h << ", " << stride_w << ", " << op << ", " << layout << std::endl;
// Enqueue
shift.enqueue(&stream, {&a, &b, &c}, true);
shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::NO_TUNING);
return torchc;
}
@@ -99,7 +100,7 @@ torch::Tensor shift_y(
// run
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::FPROP, layout, x, w, bias);
triton::dnn::FPROP, layout, x, w, bias);
}
torch::Tensor shift_dx(
@@ -127,7 +128,7 @@ torch::Tensor shift_dx(
// run
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::BPROP, layout, dy, w, bias);
triton::dnn::BPROP, layout, dy, w, bias);
}
torch::Tensor shift_dw(
@@ -155,7 +156,7 @@ torch::Tensor shift_dw(
// run
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::WGRAD, layout, dy, x, bias);
triton::dnn::WGRAD, layout, dy, x, bias);
}
static auto registry =

View File

@@ -14,8 +14,6 @@ private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// default parameters
@@ -27,6 +25,9 @@ public:
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
// number of flops
size_t num_flops() const;
// triton-c source
void triton_c_src(std::ostream &os) const;

View File

@@ -0,0 +1,109 @@
#ifndef TRITON_DNN_HEURISTICS_H
#define TRITON_DNN_HEURISTICS_H

#include <map>
#include <tuple>
#include <vector>
#include "triton/dnn/base.h"

namespace triton{
namespace dnn{

// A tuning configuration: the ordered list of meta-parameter values fed to
// the Triton JIT for one kernel variant.
typedef std::vector<unsigned> params_t;
// Table key on the GEMM transposition mode: (A transposed, B transposed).
typedef std::tuple<bool, bool> trans_key_t;
// Table key on the tile shape: (TM, TN).
typedef std::tuple<size_t, size_t> size_key_t;

// Pre-tuned kernel configurations, indexed first by transposition mode
// (NN / NT / TN / TT) and then by (TM, TN) tile shape.
// NOTE(review): the per-entry parameter layout differs between modes and is
// defined by the corresponding triton-c kernel template — presumably tile
// sizes, warp counts and vector widths; confirm against the dot/shift
// templates before editing values. Being `static const` in a header, each
// translation unit gets its own copy of this table.
static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
  /* NN */
  {trans_key_t(false, false), std::map<size_key_t, params_t>{
    {size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
    {size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
    {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
    {size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
    {size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
    {size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
    {size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
    {size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
    {size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
    {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
    {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
    {size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
    {size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
    {size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
    {size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
    {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
  }},
  /* NT */
  {trans_key_t(false, true), std::map<size_key_t, params_t>{
    {size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
    {size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
    {size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
    {size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
    {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
    {size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
    {size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
    {size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
    {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
    {size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
    {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
    {size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
    {size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
    {size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
    {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
    {size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
  }},
  /* TN */
  {trans_key_t(true, false), std::map<size_key_t, params_t>{
    {size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
    {size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
    {size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
    {size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
    {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
    {size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
    {size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
    {size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
    {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
  }},
  /* TT */
  {trans_key_t(true, true), std::map<size_key_t, params_t>{
    {size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
    {size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
    {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
    {size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
    {size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
    {size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
    {size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
    {size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
    {size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
    {size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
    {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
    {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
    {size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
    {size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
    {size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
    {size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
  }}
};

// Small search space for partial auto-tuning: returns every pre-tuned
// configuration for the requested transposition mode (AT, BT).
// Throws std::out_of_range if the mode is not in the table (cannot happen
// for the four bool/bool combinations above).
inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
  const std::map<size_key_t, params_t>& space = params.at(trans_key_t{AT, BT});
  std::vector<params_t> result;
  result.reserve(space.size());
  for(const auto& x: space)
    result.push_back(x.second);
  return result;
}

// Simple parameter heuristic for a GEMM of shape (M, N, K).
// For now this is a skeleton: it always selects the 128x128 tile regardless
// of the problem shape; M, N and K are kept in the signature so a
// shape-aware selection can be added without changing callers.
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
  (void)M; (void)N; (void)K; // unused until shape-aware selection lands
  size_t TM = 128;
  size_t TN = 128;
  return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
}

}
}

#endif

View File

@@ -35,20 +35,18 @@
namespace triton{
namespace dnn{
enum op_t {
FPROP,
BPROP,
WGRAD
};
enum layout_t {
NCHW,
CHWN
};
class shift: public base {
public:
enum op_t {
FPROP,
BPROP,
WGRAD
};
enum layout_t {
NCHW,
CHWN
};
private:
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
@@ -56,7 +54,8 @@ private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
std::vector<unsigned> default_params() const;
std::vector<params_t> search_space() const;
params_t heuristics() const;
public:

View File

@@ -66,8 +66,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
}
/* retrieved compiled template */
else
else{
jit = m_jit.at(this).get();
}
/* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -1,6 +1,7 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/gemm.h"
#include "triton/dnn/heuristics.h"
#include <string>
namespace triton{
@@ -147,99 +148,12 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
// small search space for partial auto-tuning
std::vector<params_t> dot::search_space() const {
typedef std::vector<unsigned> params_t;
typedef std::tuple<size_t, size_t> key_t;
static std::vector<key_t> keys = {
{16, 16}, {16, 32}, {16, 64}, {16, 128},
{32, 16}, {32, 32}, {32, 64}, {32, 128},
{64, 16}, {64, 32}, {64, 64}, {64, 128},
{128, 16},{128, 32},{128, 64},{128, 128}
};
static std::vector<params_t> space_nn = {
{4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
{4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
{4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
{4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
{8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
{8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
{8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
{8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
{8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
};
static std::vector<params_t> space_nt = {
{4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
{4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
{4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
{8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
{4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
{8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
{8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
{8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
{8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
{16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
{8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
{8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
};
static std::vector<params_t> space_tn = {
{8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
{4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
{16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
{4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
{8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
{16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
};
static std::vector<params_t> space_tt = {
{4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
{16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
{4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
{32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
{8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
{32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
{32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
{32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
{32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
};
if(!AT_ && !BT_)
return space_nn;
else if(!AT_ && BT_)
return space_nt;
else if(AT_ && !BT_)
return space_tn;
else
return space_tt;
return dot_search_space(AT_, BT_);
}
// simple parameter heuristics
params_t dot::heuristics() const {
return search_space().back();
return dot_heuristics(AT_, BT_, M_, N_, K_);
}
}

View File

@@ -1,5 +1,6 @@
#include <sstream>
#include "triton/dnn/shift.h"
#include "triton/dnn/heuristics.h"
#include "triton/tools/bench.hpp"
namespace triton{
@@ -513,12 +514,14 @@ else{
}
// small search space for partial auto-tuning
std::vector<params_t> shift::search_space() const {
return dot_search_space(AT_, BT_);
}
// simple parameter heuristics
std::vector<unsigned> shift::default_params() const {
typedef std::vector<unsigned> params_t;
std::map<std::tuple<op_t, size_t, size_t>, params_t> params = {
{{}, {}}
};
params_t shift::heuristics() const {
return dot_heuristics(AT_, BT_, M_, N_, K_);
}

View File

@@ -211,9 +211,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};