[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner
This commit is contained in:
@@ -6,38 +6,43 @@
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
struct cmp_recompile{
|
||||
bool operator()(base* x, base* y) const{
|
||||
return *x < *y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[size - 1] = 1;
|
||||
for(int i = size - 1; i >= 1; i--)
|
||||
ld[i - 1] = shapes[i] * ld[i];
|
||||
}
|
||||
|
||||
|
||||
base::base(const std::string& name)
|
||||
: name_(name) { }
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
|
||||
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
|
||||
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
|
||||
bool autotune = false;
|
||||
driver::context* ctx = stream->context();
|
||||
triton::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
jit = m_jit.emplace(this, new triton::jit(ctx)).first->second.get();
|
||||
jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
get_src(oss);
|
||||
triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
|
||||
enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
|
||||
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return get_nflops() / ts * 1e-3;
|
||||
return num_flops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
@@ -57,12 +62,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
|
||||
/* get launch parameters */
|
||||
driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
|
||||
/* launch */
|
||||
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
|
||||
enqueue_impl(stream, kernel, args,
|
||||
info.global_range_size, info.num_threads);
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user