[general] a bunch of fixes in anticipation of proper triton vs cudnn
benchmarks * DNN: Added partial auto-tuning mode and skeleton for heuristics * Examples: Moduralized benchmarking and now evaluating ResNet-18 shapes
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
@@ -16,7 +17,7 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
||||
double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float T;
|
||||
std::string ty = "fp16";
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
@@ -39,11 +40,11 @@ double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int3
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
|
||||
double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
delete dc;
|
||||
delete da;
|
||||
delete db;
|
||||
return result;
|
||||
return dot.num_flops() / nanosec * 1e-3;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -53,20 +54,28 @@ int main() {
|
||||
int32_t M;
|
||||
int32_t N;
|
||||
int32_t K;
|
||||
|
||||
std::string repr() {
|
||||
std::ostringstream oss;
|
||||
oss << AT << " " << BT << " " << M << " " << N << " " << K;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, AT, BT, M, N, K);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
{false, false, 4096, 4096, 4096},
|
||||
{false, true, 4096, 4096, 4096},
|
||||
{true, false, 4096, 4096, 4096},
|
||||
{true, true, 4096, 4096, 4096}
|
||||
{false, true, 4096, 4096, 4096},
|
||||
{true, false, 4096, 4096, 4096},
|
||||
{true, true, 4096, 4096, 4096}
|
||||
};
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// does the work
|
||||
for(config_t c: configs){
|
||||
double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
|
||||
double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
|
||||
std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
}
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@
|
||||
|
||||
double do_bench(triton::driver::context* context,
|
||||
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
|
||||
triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout,
|
||||
triton::dnn::op_t op, triton::dnn::layout_t layout,
|
||||
std::string numeric_t) {
|
||||
typedef float NumericT;
|
||||
|
||||
@@ -25,14 +25,14 @@ double do_bench(triton::driver::context* context,
|
||||
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
|
||||
shift_h.data(), shift_w.data(),
|
||||
numeric_t, numeric_t,
|
||||
op, false, triton::dnn::shift::CHWN);
|
||||
op, false, layout);
|
||||
// host buffers
|
||||
size_t a_size = B*C*H*W;
|
||||
size_t b_size = C*F;
|
||||
size_t c_size = B*F*H*W;
|
||||
if(op == triton::dnn::shift::BPROP)
|
||||
if(op == triton::dnn::BPROP)
|
||||
std::swap(a_size, c_size);
|
||||
if(op == triton::dnn::shift::WGRAD){
|
||||
if(op == triton::dnn::WGRAD){
|
||||
std::swap(b_size, c_size);
|
||||
std::swap(a_size, b_size);
|
||||
}
|
||||
@@ -58,20 +58,57 @@ double do_bench(triton::driver::context* context,
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
shift.enqueue(stream, {da, db, dc}, true);
|
||||
double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, true);}, stream);
|
||||
std::cout << tns << std::endl;
|
||||
double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream);
|
||||
return shift.num_flops() / nanosec * 1e-3;
|
||||
}
|
||||
|
||||
int main() {
|
||||
using triton::dnn::op_t;
|
||||
using triton::dnn::layout_t;
|
||||
|
||||
struct config_t{
|
||||
int32_t B;
|
||||
int32_t C;
|
||||
int32_t H;
|
||||
int32_t W;
|
||||
int32_t R;
|
||||
int32_t S;
|
||||
int32_t F;
|
||||
int32_t stride_h;
|
||||
int32_t stride_w;
|
||||
op_t op;
|
||||
layout_t layout;
|
||||
std::string ty;
|
||||
|
||||
std::string repr() {
|
||||
std::ostringstream oss;
|
||||
oss << B << ", " << C << ", " << H << ", " << W << ", " << R << ", " << S << ", " << F << ", " << op << ", " << layout << ", " << ty;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, R, S, B, F, H, W, C, op, layout, ty);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs;
|
||||
std::vector<config_t> resnet18 = {
|
||||
{128, 128, 32, 32, 3, 3, 128, 1, 1},
|
||||
{128, 128, 32, 32, 3, 3, 256, 2, 2},
|
||||
{128, 256, 16, 16, 3, 3, 256, 1, 1},
|
||||
{128, 256, 16, 16, 3, 3, 512, 2, 2},
|
||||
{128, 512, 8, 8, 3, 3, 512, 1, 1},
|
||||
{128, 512, 8, 8, 3, 3, 1024, 1, 1},
|
||||
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
|
||||
};
|
||||
for(config_t c: resnet18){
|
||||
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
|
||||
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
|
||||
}
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// shapes
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 16, F = 4096;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t C = 4096;
|
||||
// benchmark
|
||||
do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16");
|
||||
for(config_t c: configs)
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
|
||||
}
|
||||
|
@@ -11,14 +11,14 @@
|
||||
|
||||
void extract_shapes(const torch::Tensor &x,
|
||||
int64_t &C, int64_t &H, int64_t &W, int64_t &B,
|
||||
triton::dnn::shift::layout_t layout) {
|
||||
if(layout == triton::dnn::shift::CHWN){
|
||||
triton::dnn::layout_t layout) {
|
||||
if(layout == triton::dnn::CHWN){
|
||||
C = x.size(0);
|
||||
H = x.size(1);
|
||||
W = x.size(2);
|
||||
B = x.size(3);
|
||||
}
|
||||
else if(layout == triton::dnn::shift::NCHW){
|
||||
else if(layout == triton::dnn::NCHW){
|
||||
B = x.size(0);
|
||||
C = x.size(1);
|
||||
H = x.size(2);
|
||||
@@ -29,14 +29,14 @@ void extract_shapes(const torch::Tensor &x,
|
||||
}
|
||||
}
|
||||
|
||||
static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW;
|
||||
static const triton::dnn::layout_t layout = triton::dnn::NCHW;
|
||||
|
||||
torch::Tensor shift_common(
|
||||
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S, int32_t F,
|
||||
int32_t stride_h, int32_t stride_w,
|
||||
int32_t* shift_h, int32_t* shift_w,
|
||||
triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
|
||||
triton::dnn::op_t op, triton::dnn::layout_t layout,
|
||||
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
|
||||
bool autotune = false
|
||||
) {
|
||||
@@ -59,7 +59,7 @@ torch::Tensor shift_common(
|
||||
triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
|
||||
stride_h, stride_w,
|
||||
shift_h, shift_w, dtype, dtype,
|
||||
ty, has_bias, layout);
|
||||
op, has_bias, layout);
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
@@ -74,8 +74,9 @@ torch::Tensor shift_common(
|
||||
|
||||
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
std::cout << B << ", " << C << ", " << H << ", " << W << ", " << T << ", " << R << ", " << S << ", " << F << ", " << stride_h << ", " << stride_w << ", " << op << ", " << layout << std::endl;
|
||||
// Enqueue
|
||||
shift.enqueue(&stream, {&a, &b, &c}, true);
|
||||
shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::NO_TUNING);
|
||||
return torchc;
|
||||
}
|
||||
|
||||
@@ -99,7 +100,7 @@ torch::Tensor shift_y(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::FPROP, layout, x, w, bias);
|
||||
triton::dnn::FPROP, layout, x, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dx(
|
||||
@@ -127,7 +128,7 @@ torch::Tensor shift_dx(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::BPROP, layout, dy, w, bias);
|
||||
triton::dnn::BPROP, layout, dy, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dw(
|
||||
@@ -155,7 +156,7 @@ torch::Tensor shift_dw(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::WGRAD, layout, dy, x, bias);
|
||||
triton::dnn::WGRAD, layout, dy, x, bias);
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
|
@@ -14,8 +14,6 @@ private:
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
bool operator<(const base& other) const;
|
||||
// default parameters
|
||||
@@ -27,6 +25,9 @@ public:
|
||||
std::string a_ty, std::string b_ty,
|
||||
unsigned alignment_lda, unsigned alignment_ldb);
|
||||
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
|
||||
// triton-c source
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
|
||||
|
109
include/triton/dnn/heuristics.h
Normal file
109
include/triton/dnn/heuristics.h
Normal file
@@ -0,0 +1,109 @@
|
||||
#ifndef TRITON_DNN_HEURISTICS_H
|
||||
#define TRITON_DNN_HEURISTICS_H
|
||||
|
||||
#include <vector>
|
||||
#include "triton/dnn/base.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
typedef std::vector<unsigned> params_t;
|
||||
typedef std::tuple<bool, bool> trans_key_t;
|
||||
typedef std::tuple<size_t, size_t> size_key_t;
|
||||
static const std::map<trans_key_t, std::map<size_key_t, params_t>> params = {
|
||||
/* NN */
|
||||
{trans_key_t(false, false), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}},
|
||||
{size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}},
|
||||
{size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}}
|
||||
}},
|
||||
/* NT */
|
||||
{trans_key_t(false, true), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}},
|
||||
{size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
|
||||
{size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}},
|
||||
{size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}},
|
||||
{size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
|
||||
{size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}},
|
||||
{size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}},
|
||||
{size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}},
|
||||
{size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}},
|
||||
{size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}},
|
||||
{size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}},
|
||||
{size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}},
|
||||
{size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}},
|
||||
{size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}},
|
||||
{size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}},
|
||||
{size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}}
|
||||
}},
|
||||
/* TN */
|
||||
{trans_key_t(true, false), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1}},
|
||||
{size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}},
|
||||
{size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}},
|
||||
{size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}},
|
||||
{size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
|
||||
{size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}},
|
||||
}},
|
||||
/* TT */
|
||||
{trans_key_t(true, true), std::map<size_key_t, params_t>{
|
||||
{size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
|
||||
{size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}},
|
||||
{size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}},
|
||||
{size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}},
|
||||
{size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}},
|
||||
{size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}},
|
||||
{size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}}
|
||||
}}
|
||||
};
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
std::vector<params_t> result;
|
||||
for(auto x: params.at(trans_key_t{AT, BT}))
|
||||
result.push_back(x.second);
|
||||
return result;
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
@@ -35,20 +35,18 @@
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
enum op_t {
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
enum layout_t {
|
||||
NCHW,
|
||||
CHWN
|
||||
};
|
||||
|
||||
class shift: public base {
|
||||
|
||||
public:
|
||||
enum op_t {
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
enum layout_t {
|
||||
NCHW,
|
||||
CHWN
|
||||
};
|
||||
|
||||
private:
|
||||
// initialize and enqueue
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
@@ -56,7 +54,8 @@ private:
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
std::vector<unsigned> default_params() const;
|
||||
std::vector<params_t> search_space() const;
|
||||
params_t heuristics() const;
|
||||
|
||||
public:
|
||||
|
||||
|
@@ -66,8 +66,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
/* retrieved compiled template */
|
||||
else
|
||||
else{
|
||||
jit = m_jit.at(this).get();
|
||||
}
|
||||
|
||||
/* get launch parameters */
|
||||
driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/dnn/gemm.h"
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
@@ -147,99 +148,12 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
std::vector<params_t> dot::search_space() const {
|
||||
typedef std::vector<unsigned> params_t;
|
||||
typedef std::tuple<size_t, size_t> key_t;
|
||||
static std::vector<key_t> keys = {
|
||||
{16, 16}, {16, 32}, {16, 64}, {16, 128},
|
||||
{32, 16}, {32, 32}, {32, 64}, {32, 128},
|
||||
{64, 16}, {64, 32}, {64, 64}, {64, 128},
|
||||
{128, 16},{128, 32},{128, 64},{128, 128}
|
||||
};
|
||||
static std::vector<params_t> space_nn = {
|
||||
{4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
|
||||
{2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
|
||||
{4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
|
||||
{4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
|
||||
{4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
|
||||
{4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
|
||||
{8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
|
||||
{8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
|
||||
{8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
|
||||
{8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
|
||||
{8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
|
||||
{16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
|
||||
{8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
|
||||
{8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
|
||||
{8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
|
||||
{8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
|
||||
};
|
||||
static std::vector<params_t> space_nt = {
|
||||
{4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
|
||||
{4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
|
||||
{4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
|
||||
{4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
|
||||
{8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
|
||||
{4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
|
||||
{16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
|
||||
{4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
|
||||
{8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
|
||||
{8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
|
||||
{8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
|
||||
{8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
|
||||
{8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
|
||||
{16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
|
||||
{8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
|
||||
{8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
|
||||
};
|
||||
static std::vector<params_t> space_tn = {
|
||||
{8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
|
||||
{4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
|
||||
{4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
|
||||
{16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
|
||||
{4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
|
||||
{8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
|
||||
{8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
|
||||
{32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
|
||||
{8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
|
||||
{8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
|
||||
{16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
|
||||
{32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
|
||||
{16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
|
||||
{32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
|
||||
{32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
|
||||
{32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
|
||||
};
|
||||
static std::vector<params_t> space_tt = {
|
||||
{4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
|
||||
{8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
|
||||
{16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
|
||||
{16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
|
||||
{4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
|
||||
{8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
|
||||
{16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
|
||||
{32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
|
||||
{8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
|
||||
{8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
|
||||
{16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
|
||||
{32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
|
||||
{16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
|
||||
{32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
|
||||
{32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
|
||||
{32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
|
||||
};
|
||||
if(!AT_ && !BT_)
|
||||
return space_nn;
|
||||
else if(!AT_ && BT_)
|
||||
return space_nt;
|
||||
else if(AT_ && !BT_)
|
||||
return space_tn;
|
||||
else
|
||||
return space_tt;
|
||||
return dot_search_space(AT_, BT_);
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
params_t dot::heuristics() const {
|
||||
return search_space().back();
|
||||
return dot_heuristics(AT_, BT_, M_, N_, K_);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include <sstream>
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/dnn/heuristics.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
namespace triton{
|
||||
@@ -513,12 +514,14 @@ else{
|
||||
}
|
||||
|
||||
|
||||
// small search space for partial auto-tuning
|
||||
std::vector<params_t> shift::search_space() const {
|
||||
return dot_search_space(AT_, BT_);
|
||||
}
|
||||
|
||||
// simple parameter heuristics
|
||||
std::vector<unsigned> shift::default_params() const {
|
||||
typedef std::vector<unsigned> params_t;
|
||||
std::map<std::tuple<op_t, size_t, size_t>, params_t> params = {
|
||||
{{}, {}}
|
||||
};
|
||||
params_t shift::heuristics() const {
|
||||
return dot_heuristics(AT_, BT_, M_, N_, K_);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -211,9 +211,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
best.perf = perf;
|
||||
best.params = params;
|
||||
}
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
// for(size_t i = 0; i < params.size(); i++)
|
||||
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user