more work on heuristics
This commit is contained in:
@@ -7,8 +7,6 @@ namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
|
||||
|
||||
|
||||
void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
@@ -22,7 +20,15 @@ void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
base::base(const std::string& name)
|
||||
: name_(name) { }
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
|
||||
std::vector<params_t> base::search_space() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
params_t base::heuristics() const {
|
||||
return *search_space().begin();
|
||||
}
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
namespace rt = triton::runtime;
|
||||
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
|
||||
driver::context* ctx = stream->context();
|
||||
@@ -30,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
base* clone = this->clone();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
|
||||
std::ostringstream oss;
|
||||
clone->triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
@@ -40,18 +46,21 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
clone->enqueue_impl(stream, kernel, args, info);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
|
||||
clone->deinit_impl();
|
||||
return num_flops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
|
||||
if(autotune != NO_TUNING) {
|
||||
std::vector<params_t> space = {};
|
||||
if(autotune == PARTIAL_TUNING)
|
||||
space = search_space();
|
||||
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||
params_t params = heuristics();
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
|
Reference in New Issue
Block a user