diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 4c9f51960..a8723f2e2 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "triton/runtime/jit.h" #include "triton/driver/backend.h" @@ -16,7 +17,7 @@ void diff(const std::vector& x, const std::vector& y){ std::cout << "Pass!" << std::endl; } -double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ +double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ typedef float T; std::string ty = "fp16"; size_t dt_nbytes = sizeof(T); @@ -39,11 +40,11 @@ double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int3 stream->write(dc, true, 0, hc); stream->synchronize(); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8); - double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream); + double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream); delete dc; delete da; delete db; - return result; + return dot.num_flops() / nanosec * 1e-3; } int main() { @@ -53,20 +54,28 @@ int main() { int32_t M; int32_t N; int32_t K; + + std::string repr() { + std::ostringstream oss; + oss << AT << " " << BT << " " << M << " " << N << " " << K; + return oss.str(); + } + + double perf(triton::driver::context *context){ + return do_bench(context, AT, BT, M, N, K); + } }; // shapes to benchmark std::vector configs = { {false, false, 4096, 4096, 4096}, - {false, true, 4096, 4096, 4096}, - {true, false, 4096, 4096, 4096}, - {true, true, 4096, 4096, 4096} + {false, true, 4096, 4096, 4096}, + {true, false, 4096, 4096, 4096}, + {true, true, 4096, 4096, 4096} }; // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // does the work for(config_t c: configs){ - double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K); - double tflops = 2.*c.M*c.N*c.K / tns * 1e-3; - std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl; + std::cout << c.repr() << ", " << c.perf(context) << std::endl; } } diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 388523de2..c4074c722 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -10,7 +10,7 @@ double do_bench(triton::driver::context* context, int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C, - triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout, + triton::dnn::op_t op, triton::dnn::layout_t layout, std::string numeric_t) { typedef float NumericT; @@ -25,14 +25,14 @@ double do_bench(triton::driver::context* context, triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t, numeric_t, - op, false, triton::dnn::shift::CHWN); + op, false, layout); // host buffers size_t a_size = B*C*H*W; size_t b_size = C*F; size_t c_size = B*F*H*W; - if(op == triton::dnn::shift::BPROP) + if(op == triton::dnn::BPROP) std::swap(a_size, c_size); - if(op == triton::dnn::shift::WGRAD){ + if(op == triton::dnn::WGRAD){ std::swap(b_size, c_size); std::swap(a_size, b_size); } @@ -58,20 +58,57 @@ double do_bench(triton::driver::context* context, stream->write(db, true, 0, hb); stream->write(dc, true, 0, hc); stream->synchronize(); - shift.enqueue(stream, {da, db, dc}, true); - double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, true);}, stream); - std::cout << tns << std::endl; + double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream); + return shift.num_flops() / nanosec * 1e-3; } int main() { + using triton::dnn::op_t; + using triton::dnn::layout_t; + + struct config_t{ + int32_t B; + int32_t C; + int32_t H; + int32_t W; + int32_t R; + int32_t S; + int32_t F; + int32_t stride_h; + int32_t stride_w; + op_t op; + layout_t layout; + std::string ty; + + std::string repr() { + std::ostringstream oss; + oss << B << ", " << C << ", " << H << ", " << W << ", " << R << ", " << S << ", " << F << ", " << op << ", " << layout << ", " << ty; + return oss.str(); + } + + double perf(triton::driver::context *context){ + return do_bench(context, R, S, B, F, H, W, C, op, layout, ty); + } + }; + // shapes to benchmark + std::vector configs; + std::vector resnet18 = { + {128, 128, 32, 32, 3, 3, 128, 1, 1}, + {128, 128, 32, 32, 3, 3, 256, 2, 2}, + {128, 256, 16, 16, 3, 3, 256, 1, 1}, + {128, 256, 16, 16, 3, 3, 512, 2, 2}, + {128, 512, 8, 8, 3, 3, 512, 1, 1}, + {128, 512, 8, 8, 3, 3, 1024, 1, 1}, + {128, 1024, 8, 8, 3, 3, 1024, 1, 1} + }; + for(config_t c: resnet18){ + for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}) + configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"}); + } + // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); - // shapes - int32_t R = 3, S = 3; - int32_t B = 16, F = 4096; - int32_t H = 32, W = 32; - int32_t C = 4096; - // benchmark - do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16"); + for(config_t c: configs) + std::cout << c.repr() << ", " << c.perf(context) << std::endl; } diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp index a16c2922e..7c86b227e 100644 --- a/examples/python/pytorch/shift.cpp +++ b/examples/python/pytorch/shift.cpp @@ -11,14 +11,14 @@ void extract_shapes(const torch::Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B, - triton::dnn::shift::layout_t layout) { - if(layout == triton::dnn::shift::CHWN){ + triton::dnn::layout_t layout) { + if(layout == triton::dnn::CHWN){ C = x.size(0); H = x.size(1); W = x.size(2); B = x.size(3); } - else if(layout == triton::dnn::shift::NCHW){ + else if(layout == triton::dnn::NCHW){ B = x.size(0); C = x.size(1); H = x.size(2); @@ -29,14 +29,14 @@ void extract_shapes(const torch::Tensor &x, } } -static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW; +static const triton::dnn::layout_t layout = triton::dnn::NCHW; torch::Tensor shift_common( int32_t B, int32_t C, int32_t D, int32_t H, int32_t W, int32_t T, int32_t R, int32_t S, int32_t F, int32_t stride_h, int32_t stride_w, int32_t* shift_h, int32_t* shift_w, - triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout, + triton::dnn::op_t op, triton::dnn::layout_t layout, torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias, bool autotune = false ) { @@ -59,7 +59,7 @@ torch::Tensor shift_common( triton::dnn::shift shift(B, C, D, H, W, T, R, S, F, stride_h, stride_w, shift_h, shift_w, dtype, dtype, - ty, has_bias, layout); + op, has_bias, layout); // Bind memory triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false); triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false); @@ -74,8 +74,9 @@ torch::Tensor shift_common( triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false); + std::cout << B << ", " << C << ", " << H << ", " << W << ", " << T << ", " << R << ", " << S << ", " << F << ", " << stride_h << ", " << stride_w << ", " << op << ", " << layout << std::endl; // Enqueue - shift.enqueue(&stream, {&a, &b, &c}, true); + shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::NO_TUNING); return torchc; } @@ -99,7 +100,7 @@ torch::Tensor shift_y( // run return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), - triton::dnn::shift::FPROP, layout, x, w, bias); + triton::dnn::FPROP, layout, x, w, bias); } torch::Tensor shift_dx( @@ -127,7 +128,7 @@ torch::Tensor shift_dx( // run return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), - triton::dnn::shift::BPROP, layout, dy, w, bias); + triton::dnn::BPROP, layout, dy, w, bias); } torch::Tensor shift_dw( @@ -155,7 +156,7 @@ torch::Tensor shift_dw( // run return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), - triton::dnn::shift::WGRAD, layout, dy, x, bias); + triton::dnn::WGRAD, layout, dy, x, bias); } static auto registry = diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/gemm.h index 1e581f6a1..3df8a13a6 100644 --- a/include/triton/dnn/gemm.h +++ b/include/triton/dnn/gemm.h @@ -14,8 +14,6 @@ private: void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, triton::runtime::launch_information info); - // number of flops - size_t num_flops() const; // comparison for maps bool operator<(const base& other) const; // default parameters @@ -27,6 +25,9 @@ public: std::string a_ty, std::string b_ty, unsigned alignment_lda, unsigned alignment_ldb); + // number of flops + size_t num_flops() const; + // triton-c source void triton_c_src(std::ostream &os) const; diff --git a/include/triton/dnn/heuristics.h b/include/triton/dnn/heuristics.h new file mode 100644 index 000000000..4962c55f0 --- /dev/null +++ b/include/triton/dnn/heuristics.h @@ -0,0 +1,109 @@ +#ifndef TRITON_DNN_HEURISTICS_H +#define TRITON_DNN_HEURISTICS_H + +#include +#include "triton/dnn/base.h" + +namespace triton{ +namespace dnn{ + +typedef std::vector params_t; +typedef std::tuple trans_key_t; +typedef std::tuple size_key_t; +static const std::map> params = { + /* NN */ + {trans_key_t(false, false), std::map{ + {size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, + {size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}}, + {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}}, + {size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}}, + {size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, + {size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, + {size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, + {size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}}, + {size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}}, + {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, + {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}}, + {size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}}, + {size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}}, + {size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}}, + {size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}}, + {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}} + }}, + /* NT */ + {trans_key_t(false, true), std::map{ + {size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}}, + {size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}}, + {size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}}, + {size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}}, + {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, + {size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}}, + {size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}}, + {size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}}, + {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, + {size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}}, + {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}}, + {size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}}, + {size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}}, + {size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}}, + {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}}, + {size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}} + }}, + /* TN */ + {trans_key_t(true, false), std::map{ + {size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}}, + {size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}}, + {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}}, + {size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}}, + {size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, + }}, + /* TT */ + {trans_key_t(true, true), std::map{ + {size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}}, + {size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, + {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}}, + {size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}}, + {size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}}, + {size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, + {size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}}, + {size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}}, + {size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}}, + {size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, + {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}}, + {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}}, + {size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}}, + {size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}}, + {size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}}, + {size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}} + }} +}; + +// small search space for partial auto-tuning +inline std::vector dot_search_space(bool AT, bool BT) { + std::vector result; + for(auto x: params.at(trans_key_t{AT, BT})) + result.push_back(x.second); + return result; +} + +// simple parameter heuristics +inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) { + size_t TM = 128; + size_t TN = 128; + return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN}); +} + +} +} +#endif diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index 59d26ab44..d1f1bc972 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -35,20 +35,18 @@ namespace triton{ namespace dnn{ +enum op_t { + FPROP, + BPROP, + WGRAD +}; + +enum layout_t { + NCHW, + CHWN +}; + class shift: public base { - -public: - enum op_t { - FPROP, - BPROP, - WGRAD - }; - - enum layout_t { - NCHW, - CHWN - }; - private: // initialize and enqueue void init_impl(driver::stream *stream, driver::cu_module *module); @@ -56,7 +54,8 @@ private: void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, triton::runtime::launch_information info); - std::vector default_params() const; + std::vector search_space() const; + params_t heuristics() const; public: diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index 8224dc846..befb7c842 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -66,8 +66,9 @@ void base::enqueue(driver::stream *stream, std::vector args, a clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); } /* retrieved compiled template */ - else + else{ jit = m_jit.at(this).get(); + } /* get launch parameters */ driver::kernel* kernel = jit->get_function(name_.c_str()); diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp index 43fc9f173..33f8273bf 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/gemm.cpp @@ -1,6 +1,7 @@ #include "triton/driver/stream.h" #include "triton/driver/kernel.h" #include "triton/dnn/gemm.h" +#include "triton/dnn/heuristics.h" #include namespace triton{ @@ -147,99 +148,12 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, // small search space for partial auto-tuning std::vector dot::search_space() const { - typedef std::vector params_t; - typedef std::tuple key_t; - static std::vector keys = { - {16, 16}, {16, 32}, {16, 64}, {16, 128}, - {32, 16}, {32, 32}, {32, 64}, {32, 128}, - {64, 16}, {64, 32}, {64, 64}, {64, 128}, - {128, 16},{128, 32},{128, 64},{128, 128} - }; - static std::vector space_nn = { - {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}, - {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}, - {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}, - {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}, - {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}, - {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}, - {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}, - {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}, - {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}, - {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}, - {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}, - {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}, - {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}, - {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}, - {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}, - {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}, - }; - static std::vector space_nt = { - {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}, - {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}, - {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}, - {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}, - {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}, - {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}, - {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}, - {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}, - {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}, - {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}, - {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}, - {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}, - {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}, - {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}, - {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}, - {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}, - }; - static std::vector space_tn = { - {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}, - {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}, - {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}, - {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}, - {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}, - {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}, - {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}, - {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}, - {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}, - {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}, - {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}, - {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}, - {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}, - {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}, - {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}, - {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}, - }; - static std::vector space_tt = { - {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}, - {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}, - {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}, - {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}, - {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}, - {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}, - {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}, - {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}, - {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}, - {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}, - {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}, - {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}, - {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}, - {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}, - {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}, - {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1} - }; - if(!AT_ && !BT_) - return space_nn; - else if(!AT_ && BT_) - return space_nt; - else if(AT_ && !BT_) - return space_tn; - else - return space_tt; + return dot_search_space(AT_, BT_); } // simple parameter heuristics params_t dot::heuristics() const { - return search_space().back(); + return dot_heuristics(AT_, BT_, M_, N_, K_); } } diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index e8a4f3584..e09ce0a58 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -1,5 +1,6 @@ #include #include "triton/dnn/shift.h" +#include "triton/dnn/heuristics.h" #include "triton/tools/bench.hpp" namespace triton{ @@ -513,12 +514,14 @@ else{ } +// small search space for partial auto-tuning +std::vector shift::search_space() const { + return dot_search_space(AT_, BT_); +} + // simple parameter heuristics -std::vector shift::default_params() const { - typedef std::vector params_t; - std::map, params_t> params = { - {{}, {}} - }; +params_t shift::heuristics() const { + return dot_heuristics(AT_, BT_, M_, N_, K_); } diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 6f25de8da..90f9a0e4c 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -211,9 +211,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben best.perf = perf; best.params = params; } - for(size_t i = 0; i < params.size(); i++) - std::cout << ((i==0)?"":", ") << params[i] << std::flush; - std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; +// for(size_t i = 0; i < params.size(); i++) +// std::cout << ((i==0)?"":", ") << params[i] << std::flush; +// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; } };