diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 0eac9c046..4c9f51960 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -6,19 +6,21 @@
 #include "triton/dnn/gemm.h"
 #include "triton/tools/bench.hpp"
 
+template<class T>
+void diff(const std::vector<T>& x, const std::vector<T>& y){
+  for(size_t i = 0; i < x.size(); i++)
+    if(std::isnan(x[i]) || std::abs(x[i] - y[i])/std::max(x[i], y[i]) > 1e-4){
+      std::cout << i << " " << x[i] << " " << y[i] << std::endl;
+      exit(EXIT_FAILURE);
+    }
+  std::cout << "Pass!" << std::endl;
+}
 
-int main() {
-  bool AT = false;
-  bool BT = true;
+double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float T;
   std::string ty = "fp16";
   size_t dt_nbytes = sizeof(T);
-  // initialize default compute device
-  auto context = triton::driver::backend::contexts::get_default();
-  // matrix multiplication parameters
-  int32_t M = 65536, N = 2048, K = 2048;
   std::vector<T> hc(M*N);
-  std::vector<T> rc(M*N);
   std::vector<T> ha(M*K);
   std::vector<T> hb(K*N);
   srand(0);
@@ -36,14 +38,35 @@ int main() {
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
-  gemm.enqueue(stream, {da, db, dc}, true);
-//  stream->read(dc, true, 0, hc);
-//  gemm.cpu_ref(rc, ha, hb);
-//  for(size_t i = 0; i < M*N; i++)
-//    if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
-//      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-//      exit(EXIT_FAILURE);
-//    }
-//  std::cout << "Pass!" << std::endl;
+  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
+  double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
+  delete dc;
+  delete da;
+  delete db;
+  return result;
+}
+
+int main() {
+  struct config_t{
+    bool AT;
+    bool BT;
+    int32_t M;
+    int32_t N;
+    int32_t K;
+  };
+  // shapes to benchmark
+  std::vector<config_t> configs = {
+    {false, false, 4096, 4096, 4096},
+    {false, true,  4096, 4096, 4096},
+    {true,  false, 4096, 4096, 4096},
+    {true,  true,  4096, 4096, 4096}
+  };
+  // initialize default compute device
+  auto context = triton::driver::backend::contexts::get_default();
+  // does the work
+  for(config_t c: configs){
+    double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
+    double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
+    std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
+  }
 }
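For reference, the TFLOPS arithmetic in the new main() works out as follows: triton::tools::bench returns a time in nanoseconds, so FLOP per nanosecond equals GFLOPS, and the 1e-3 factor converts to TFLOPS. A minimal standalone sketch (not part of the patch; the timing value below is made up):

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t M = 4096, N = 4096, K = 4096;
      double tns = 1.374e7;                  // hypothetical: 13.74 ms for one GEMM
      double gflops = 2. * M * N * K / tns;  // 1.374e11 FLOP / 1.374e7 ns = 1e4 FLOP/ns = 1e4 GFLOPS
      double tflops = gflops * 1e-3;         // = 10 TFLOPS
      std::cout << tflops << std::endl;      // prints 10
      return 0;
    }
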
diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 3dabddfe2..388523de2 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -8,31 +8,23 @@
 #include "triton/dnn/shift.h"
 #include "triton/external/half.hpp"
 
-int main() {
+double do_bench(triton::driver::context* context,
+                int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
+                triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout,
+                std::string numeric_t) {
   typedef float NumericT;
-  std::string numeric_t_str = "fp16";
-
-  // initialize default compute device
-  auto context = triton::driver::backend::contexts::get_default();
-  auto op = triton::dnn::shift::FPROP;
-
-  // initialization
-  int32_t R = 3, S = 3;
-  int32_t B = 64, F = 2048;
-  int32_t H = 32, W = 32;
-  int32_t C = 2048;
   // random shifts
   std::vector<int32_t> shift_h(C);
   std::vector<int32_t> shift_w(C);
   for(int32_t c = 0; c < C; c++){
-    shift_h[c] = 0;
-    shift_w[c] = 0;
+    shift_h[c] = rand() % R - R / 2;
+    shift_w[c] = rand() % S - S / 2;
   }
   // configuration
   triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
                            shift_h.data(), shift_w.data(),
-                           numeric_t_str, numeric_t_str,
+                           numeric_t, numeric_t,
                            op, false, triton::dnn::shift::CHWN);
   // host buffers
   size_t a_size = B*C*H*W;
@@ -67,13 +59,19 @@ int main() {
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  shift.enqueue(stream, {da, db, dc}, true);
-//  stream->read(dc, true, 0, hc);
-//  shift.cpu_ref(rc.data(), ha.data(), hb.data());
-//  for(size_t i = 0; i < hc.size(); i++)
-//    if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
-//      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-//      exit(EXIT_FAILURE);
-//    }
-//  std::cout << "Pass!" << std::endl;
+  shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);
+  double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+  std::cout << tns << std::endl;
+  return tns;
+}
+
+int main() {
+  // initialize default compute device
+  auto context = triton::driver::backend::contexts::get_default();
+  // shapes
+  int32_t R = 3, S = 3;
+  int32_t B = 16, F = 4096;
+  int32_t H = 32, W = 32;
+  int32_t C = 4096;
+  // benchmark
+  do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16");
 }
diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp
index 7efe0198b..a16c2922e 100644
--- a/examples/python/pytorch/shift.cpp
+++ b/examples/python/pytorch/shift.cpp
@@ -36,7 +36,7 @@ torch::Tensor shift_common(
     int32_t T, int32_t R, int32_t S, int32_t F,
     int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w,
-    triton::dnn::shift::type ty, triton::dnn::shift::layout_t layout,
+    triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
     torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
     bool autotune = false
     ) {
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index 9bd25eeb3..84f67664c 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
     triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<Eigen::half>().data(), false);
     // template
-    triton::dnn::gemm dot(M, N, K, false, false, "fp16", "fp16", 4, 4);
+    triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
   }
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 1834cadaf..1a9ebbe59 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -19,7 +19,7 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
-template<triton::dnn::shift::type OP>
+template<triton::dnn::shift::op_t OP>
 class ShiftConvOp : public OpKernel {
 public:
   explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
index 3045ffb49..1fbded42c 100644
--- a/include/triton/dnn/base.h
+++ b/include/triton/dnn/base.h
@@ -31,6 +31,13 @@ namespace triton{
 namespace dnn{
 
+enum autotuning_t{
+  FULL_TUNING,
+  PARTIAL_TUNING,
+  NO_TUNING
+};
+
+typedef std::vector<unsigned> params_t;
 
 class base {
   friend class cmp_recompile;
@@ -53,6 +60,9 @@ private:
   virtual size_t num_flops() const = 0;
   // comparison for maps
   virtual bool operator<(const base& other) const = 0;
+  // default parameters
+  virtual std::vector<params_t> search_space() const;
+  virtual params_t heuristics() const;
 
 public:
   // constructor
@@ -62,7 +72,7 @@ public:
   // clone
   virtual base* clone() const = 0;
   // enqueue
-  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, bool autotune = false);
+  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
 
 private:
   std::string name_;
diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/gemm.h
index f43370606..1e581f6a1 100644
--- a/include/triton/dnn/gemm.h
+++ b/include/triton/dnn/gemm.h
@@ -6,7 +6,7 @@
 namespace triton{
 namespace dnn{
 
-class gemm: public base {
+class dot: public base {
 private:
   // initialize
   void init_impl(driver::stream *, driver::cu_module *);
@@ -18,10 +18,12 @@ private:
   size_t num_flops() const;
   // comparison for maps
   bool operator<(const base& other) const;
-
+  // default parameters
+  virtual std::vector<params_t> search_space() const;
+  virtual params_t heuristics() const;
 
 public:
-  gemm(int M, int N, int K, bool AT, bool BT,
+  dot(int M, int N, int K, bool AT, bool BT,
       std::string a_ty, std::string b_ty,
       unsigned alignment_lda, unsigned alignment_ldb);
 
@@ -46,13 +48,13 @@ public:
   template<class T>
   void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
     if(AT_ && BT_)
-      gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
+      dot::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
     else if(AT_ && !BT_)
-      gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
+      dot::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
     else if(!AT_ && BT_)
-      gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
+      dot::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
     else
-      gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
+      dot::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
   }
 
 private:
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index 84c6ccda7..59d26ab44 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -38,7 +38,7 @@ namespace dnn{
 
 class shift: public base {
 public:
-  enum type {
+  enum op_t {
     FPROP,
     BPROP,
     WGRAD
@@ -56,6 +56,7 @@ private:
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
                     triton::runtime::launch_information info);
+  std::vector<unsigned> default_params() const;
 
 public:
 
@@ -65,7 +66,7 @@ public:
         int stride_h, int stride_w,
         const int32_t* shift_h, const int32_t* shift_w,
         std::string a_ty = "fp32", std::string b_ty = "fp32",
-        type ty = FPROP, bool bias = false, layout_t layout = CHWN);
+        op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
 
   // look-up table
   void build_delta_a();
@@ -165,7 +166,7 @@ private:
   std::string b_ty_;
   std::string c_ty_;
   // convolution type
-  type op_;
+  op_t op_;
   bool bias_;
   // transpose
   bool AT_;
diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h
index c1f4f01f9..9803a163e 100755
--- a/include/triton/driver/dispatch.h
+++ b/include/triton/driver/dispatch.h
@@ -167,6 +167,8 @@ public:
   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
   static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
   static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
+
+  // SPIR-V libraries
   static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
   static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index d3088d73b..f1da2a5a2 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -108,7 +108,7 @@ public:
   jit(driver::context* context, unsigned nthreads = 4);
   ~jit();
   std::vector<unsigned> get_valid(const char *name, const char *src);
-  tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
+  tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> &targets = {});
   void add_module(ir::module &module, const std::vector<unsigned>& params = {});
   void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
   driver::kernel* get_function(const char* name);
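At call sites, the former bool autotune flag becomes the three-valued autotuning_t declared above. A sketch (not part of the patch), assuming buffers and a stream set up as in examples/cpp/dot.cpp:

    #include "triton/dnn/gemm.h"

    void run(triton::driver::stream* stream,
             triton::driver::buffer* da, triton::driver::buffer* db, triton::driver::buffer* dc) {
      triton::dnn::dot op(4096, 4096, 4096, /*AT*/false, /*BT*/true, "fp16", "fp16", 8, 8);
      op.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING); // exhaustive metaparameter sweep
      op.enqueue(stream, {da, db, dc});                           // default PARTIAL_TUNING: op's search_space()
      op.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);   // single configuration from heuristics()
    }
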
diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp
index 64c88cd64..6ac72fec7 100644
--- a/include/triton/tools/bench.hpp
+++ b/include/triton/tools/bench.hpp
@@ -2,6 +2,9 @@
 #define TRITON_TOOLS_BENCH_HPP
 
 #include <chrono>
+#include <functional>
+#include "triton/driver/device.h"
+#include "triton/driver/stream.h"
 
 namespace triton{
 namespace tools{
@@ -24,14 +27,14 @@ private:
   high_resolution_clock::time_point _start;
 };
 
-template<class OP, class SYNC>
-double bench(OP const & op, SYNC const & sync, const triton::driver::device * device)
+inline double bench(std::function<void()> const & op, driver::stream * stream)
 {
+  const driver::device * device = stream->context()->device();
   timer tmr;
   std::vector<size_t> times;
   double total_time = 0;
   op();
-  sync();
+  stream->synchronize();
   while(total_time*1e-9 < 1e-3){
     float norm = 1;
     // normalize clock if possible to get roughly constant result
     if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
       norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
     tmr.start();
     op();
-    sync();
+    stream->synchronize();
     times.push_back(norm*tmr.get().count());
     total_time+=times.back();
   }
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index 4e4741658..8a6e74f33 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -529,8 +529,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
   unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
   // size of each pack (interleaving)
-  pack_size_0_ = std::min<unsigned>(num_rep_0, 2);
-  pack_size_1_ = std::min<unsigned>(num_rep_1, 2);
+  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
+  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
   // number of packs (interleaving)
   num_packs_0_ = num_rep_0 / pack_size_0_;
   num_packs_1_ = num_rep_1 / pack_size_1_;
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 47b3f05fa..3f5119577 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -221,7 +221,7 @@ void tune::run(ir::module &mod) {
     }
     else {
       ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
-      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
+      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
       connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
     }
   }
@@ -239,7 +239,7 @@ void tune::run(ir::module &mod) {
     size_t addr_space = ptr_ty->get_pointer_address_space();
     if(addr_space < 4){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
+      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
       *params_.at(i).at("nts.d0") = *tmp;
     }
   }
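The bench() signature change above ripples into every caller (see lib/dnn/base.cpp below): the explicit sync callback and device pointer are replaced by the stream, which now supplies both. A migration sketch (not part of the patch; kernel_launch stands in for any enqueue):

    // before: caller provided op, sync and device separately
    //   double ts = triton::tools::bench([&](){ kernel_launch(); },
    //                                    [&](){ stream->synchronize(); },
    //                                    stream->context()->device());
    // after: bench() derives the device and synchronizes internally
    double ts = triton::tools::bench([&](){ kernel_launch(); }, stream);
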
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index 73bb474b8..8224dc846 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -7,8 +7,6 @@
 namespace triton{
 namespace dnn{
 
-
-
 void base::set_ld(const std::vector<int32_t>& shapes, std::vector<int32_t>& ld) {
   size_t size = shapes.size();
@@ -22,7 +20,15 @@ void base::set_ld(const std::vector<int32_t>& shapes,
 base::base(const std::string& name)
   : name_(name) { }
 
-void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, bool autotune) {
+std::vector<params_t> base::search_space() const {
+  return {};
+}
+
+params_t base::heuristics() const {
+  return *search_space().begin();
+}
+
+void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
   namespace rt = triton::runtime;
   static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
   driver::context* ctx = stream->context();
@@ -30,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, bool autotune) {
   /* the current template has not already been compiled */
   if(m_jit.find(this) == m_jit.end()) {
     base* clone = this->clone();
-    jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
+    jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
     std::ostringstream oss;
     clone->triton_c_src(oss);
     std::string src = oss.str();
@@ -40,18 +46,21 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, bool autotune) {
       clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
       clone->enqueue_impl(stream, kernel, args, info);
       stream->synchronize();
-      double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
-                                       [&](){ stream->synchronize(); }, ctx->device());
+      double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
       clone->deinit_impl();
       return num_flops() / ts * 1e-3;
     };
     // auto-tune and save result
-    if(autotune) {
-      rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
+    if(autotune != NO_TUNING) {
+      std::vector<params_t> space = {};
+      if(autotune == PARTIAL_TUNING)
+        space = search_space();
+      rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
       jit->add_module(name_.c_str(), src.c_str(), best.params);
     }
     else {
-      jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
+      params_t params = heuristics();
+      jit->add_module(name_.c_str(), src.c_str(), params);
     }
     triton::driver::kernel* kernel = jit->get_function(name_.c_str());
     clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
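An op opts into the new hooks by overriding the two virtuals. Note that the default base::heuristics() dereferences search_space().begin(), which is undefined for the default empty space, so any op used with NO_TUNING should override at least one of them. A hypothetical sketch (not part of the patch; the rows are borrowed from dot::search_space() below and are illustrative only):

    class my_op: public triton::dnn::base {
      // restrict PARTIAL_TUNING to two hand-picked configurations
      std::vector<triton::dnn::params_t> search_space() const override {
        return { {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
                 {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1} };
      }
      // single configuration used by NO_TUNING
      triton::dnn::params_t heuristics() const override {
        return search_space().front();
      }
      // clone(), num_flops(), operator<, triton_c_src(), ... as for dot/shift
    };
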
diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp
index 42c7793c2..43fc9f173 100644
--- a/lib/dnn/gemm.cpp
+++ b/lib/dnn/gemm.cpp
@@ -6,7 +6,7 @@
 namespace triton{
 namespace dnn{
 
-gemm::gemm(int M, int N, int K,
+dot::dot(int M, int N, int K,
          bool AT, bool BT,
          std::string a_ty, std::string b_ty,
          unsigned alignment_lda, unsigned alignment_ldb)
@@ -18,13 +18,13 @@
 }
 
-size_t gemm::num_flops() const {
+size_t dot::num_flops() const {
   return 2.*M_*N_*K_;
 }
 
 // comparison for maps
-bool gemm::operator<(const base& other) const {
-  auto *y = dynamic_cast<const gemm*>(&other);
+bool dot::operator<(const base& other) const {
+  auto *y = dynamic_cast<const dot*>(&other);
   if(!y)
     return true;
   return std::tie(M_, N_, K_, AT_, BT_,
@@ -34,18 +34,18 @@
 }
 
 // clone
-base* gemm::clone() const {
-  return new gemm(*this);
+base* dot::clone() const {
+  return new dot(*this);
 }
 
-void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
+void dot::init_impl(driver::stream* stream, driver::cu_module *) {
   std::vector<int32_t> hlocks(2048, 0);
   if(locks_ == nullptr)
     locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
   stream->write(locks_, false, 0, hlocks);
 }
 
-void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
+void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                         std::vector<driver::buffer*> args,
                         runtime::launch_information info) {
   driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -75,7 +75,7 @@
   stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
-void gemm::triton_c_src(std::ostream &os) const {
+void dot::triton_c_src(std::ostream &os) const {
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
@@ -100,8 +100,8 @@
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res = R"(
-const tunable int32 TM = {32, 64, 128, 256};
-const tunable int32 TN = {32, 64, 128, 256};
+const tunable int32 TM = {16, 32, 64, 128};
+const tunable int32 TN = {16, 32, 64, 128};
 const tunable int32 TK = {32};
 const tunable int32 GZ = {1};
 
@@ -145,5 +145,102 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   os << res;
 }
 
+// small search space for partial auto-tuning
+std::vector<params_t> dot::search_space() const {
+  typedef std::vector<unsigned> params_t;
+  typedef std::tuple<size_t, size_t> key_t;
+  static std::vector<key_t> keys = {
+    {16, 16}, {16, 32}, {16, 64}, {16, 128},
+    {32, 16}, {32, 32}, {32, 64}, {32, 128},
+    {64, 16}, {64, 32}, {64, 64}, {64, 128},
+    {128, 16},{128, 32},{128, 64},{128, 128}
+  };
+  static std::vector<params_t> space_nn = {
+    {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
+    {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
+    {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
+    {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
+    {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
+    {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
+    {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
+    {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
+    {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
+    {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
+    {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
+    {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
+    {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
+    {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
+    {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
+    {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
+  };
+  static std::vector<params_t> space_nt = {
+    {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
+    {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
+    {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
+    {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
+    {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
+    {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
+    {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
+    {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
+    {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
+    {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
+    {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
+    {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
+    {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
+    {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
+    {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
+    {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
+  };
+  static std::vector<params_t> space_tn = {
+    {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
+    {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
+    {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
+    {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
+    {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
+    {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
+    {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
+    {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
+    {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
+    {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
+    {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
+    {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
+    {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
+    {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
+    {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
+    {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
+  };
+  static std::vector<params_t> space_tt = {
+    {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
+    {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
+    {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
+    {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
+    {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
+    {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
+    {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
+    {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
+    {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
+    {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
+    {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
+    {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
+    {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
+    {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
+    {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
+    {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
+  };
+  if(!AT_ && !BT_)
+    return space_nn;
+  else if(!AT_ && BT_)
+    return space_nt;
+  else if(AT_ && !BT_)
+    return space_tn;
+  else
+    return space_tt;
+}
+
+// simple parameter heuristics
+params_t dot::heuristics() const {
+  return search_space().back();
+}
+
 }
 }
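Each row of the search spaces above is one complete metaparameter assignment, consumed in declaration order; handing a single row straight to the JIT skips tuning altogether. A sketch (not part of the patch; `src` stands for the string produced by dot::triton_c_src, whose kernel is named matmul):

    triton::runtime::jit engine(context);
    std::vector<unsigned> params = {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}; // one NN row
    engine.add_module("matmul", src.c_str(), params);
    triton::driver::kernel* kernel = engine.get_function("matmul");
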
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index adc36740c..e8a4f3584 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -13,7 +13,7 @@ shift::shift(int B, int C,
              int stride_h, int stride_w,
              const int32_t *shift_h, const int32_t *shift_w,
              std::string a_ty, std::string b_ty,
-             type ty, bool bias,
+             op_t ty, bool bias,
              layout_t layout)
   : base("shift"),
     B_(B), C_(C),
@@ -512,5 +512,15 @@ else{
   os << result;
 }
 
+// simple parameter heuristics
+std::vector<unsigned> shift::default_params() const {
+  typedef std::vector<unsigned> params_t;
+  std::map<std::vector<int32_t>, params_t> params = {
+    {{}, {}}
+  };
+  return params.at({});
+}
+
 }
 }
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index c925a690c..6f25de8da 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -31,7 +31,7 @@ extern triton::lang::translation_unit *ast_root;
 namespace triton {
 namespace runtime{
 
-void loop_nest(std::vector<size_t> const & ranges,
+void parallel_loop_nest(std::vector<size_t> const & ranges,
                std::function<void(std::vector<size_t> const &)> const & f,
                size_t nthreads){
   size_t D = ranges.size();
@@ -55,7 +55,7 @@
 }
 
 template<class T>
-void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
+void parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
   //Ranges to iterate over
   std::vector<size_t> ranges;
   for(auto const & x: iterates)
@@ -68,10 +68,14 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
+template<class T>
+void parallel_for_each(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads) {
+  ThreadPool pool(nthreads);
+  for(const std::vector<T>& values: iterates)
+    pool.enqueue(f, values);
+}
 
 std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes,
                                                     llvm::LLVMContext& llvm_context, launch_information& info) {
@@ -128,7 +132,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
     ranges.push_back(mp->get_space());
   // iterate over parameters
   std::vector<unsigned> result;
-  loop_nest(ranges, [&](const std::vector<unsigned> params){
+  parallel_loop_nest(ranges, [&](const std::vector<unsigned> params){
     if(!result.empty())
       return;
     std::map<ir::value*, std::vector<std::string>> errors;
@@ -148,7 +152,7 @@
 
-jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
+jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> & targets) {
   // find metaparameters
   triton::lang::translation_unit* program = parse_program(name, src);
   auto ptt_module_0 = make_triton_module(name, triton_context_, program);
@@ -157,15 +161,12 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
   passes_wrapper passes_0(target_.get());
   passes_0.target_independent(tt_module_0);
   passes_0.tune.run(tt_module_0);
-  // create parameter ranges
-  std::vector<std::vector<unsigned>> ranges;
   auto mps = passes_0.tune.get_params(tt_module_0);
-  for(ir::metaparameter *mp: mps)
-    ranges.push_back(mp->get_space());
   // iterate over parameters
   tune_res_t best;
   std::mutex mutex;
-  loop_nest(ranges, [&](const std::vector<unsigned> params){
+  // update_best
+  auto update_best = [&](const std::vector<unsigned> params){
     std::map<ir::value*, std::vector<std::string>> errors;
     unsigned i = 0;
     {
@@ -200,10 +201,10 @@
     launch_information info;
     llvm::LLVMContext llvm_context;
     auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
+    std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
     double perf;
     {
       std::lock_guard<std::mutex> lock(mutex);
-      std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
       std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
       perf = benchmark(kernel.get(), info);
       if(perf > best.perf){
@@ -214,8 +215,21 @@
         std::cout << ((i==0)?"":", ") << params[i] << std::flush;
       std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
     }
-  }, nthreads_);
-  std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
+  };
+
+
+  if(targets.empty()) {
+    // create parameter ranges
+    std::vector<std::vector<unsigned>> ranges;
+    for(ir::metaparameter *mp: mps)
+      ranges.push_back(mp->get_space());
+    parallel_loop_nest(ranges, update_best, nthreads_);
+  }
+  else {
+    parallel_for_each(targets, update_best, nthreads_);
+  }
+
+//  std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
   return best;
 }
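With the trailing targets parameter, jit::autotune degrades gracefully: an empty list reproduces the old exhaustive sweep over every metaparameter range, while a non-empty list benchmarks only the listed configurations through parallel_for_each, which is exactly what base::enqueue now does for PARTIAL_TUNING. A sketch (not part of the patch; engine, src and benchmark as in the surrounding code):

    // full sweep (old behaviour): targets defaults to {}
    auto best = engine.autotune("matmul", src.c_str(), benchmark);
    // restricted sweep over explicit configurations
    std::vector<std::vector<unsigned>> targets = {
      {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}
    };
    best = engine.autotune("matmul", src.c_str(), benchmark, targets);
    engine.add_module("matmul", src.c_str(), best.params);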