more work on heuristics

This commit is contained in:
Philippe Tillet
2019-07-21 18:11:54 -07:00
parent 484e3871cf
commit b1d81a5802
17 changed files with 268 additions and 99 deletions

View File

@@ -7,8 +7,6 @@ namespace triton{
namespace dnn{
void base::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
@@ -22,7 +20,15 @@ void base::set_ld(const std::vector<int32_t>& shapes,
base::base(const std::string& name)
: name_(name) { }
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
std::vector<params_t> base::search_space() const {
return {};
}
params_t base::heuristics() const {
return *search_space().begin();
}
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
namespace rt = triton::runtime;
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
driver::context* ctx = stream->context();
@@ -30,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
base* clone = this->clone();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
std::ostringstream oss;
clone->triton_c_src(oss);
std::string src = oss.str();
@@ -40,18 +46,21 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
[&](){ stream->synchronize(); }, ctx->device());
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
clone->deinit_impl();
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
if(autotune != NO_TUNING) {
std::vector<params_t> space = {};
if(autotune == PARTIAL_TUNING)
space = search_space();
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
params_t params = heuristics();
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());