[dnn] added base template class for mutualized auto-tuning
This commit is contained in:
@@ -14,8 +14,6 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// initialize just-in-time compiler
|
||||
triton::jit jit(context);
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
@@ -31,7 +29,7 @@ int main() {
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
@@ -55,35 +53,7 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// set argument
|
||||
shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);},
|
||||
[&](){ stream->synchronize(); }, context->device());
|
||||
return shift.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
shift.enqueue(stream, da, db, dc);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
// for(size_t i = 0; i < hc.size(); i++)
|
||||
|
@@ -72,18 +72,18 @@ torch::Tensor shift_common(
|
||||
if(m_shift_jit.find(key) == m_shift_jit.end()){
|
||||
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
configuration->src(oss);
|
||||
configuration->get_src(oss);
|
||||
std::string src = oss.str();
|
||||
// benchmark a given shift kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, stream->context()->device());
|
||||
return configuration->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
@@ -96,7 +96,7 @@ torch::Tensor shift_common(
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_shift_jit.at(key).get();
|
||||
@@ -109,6 +109,6 @@ torch::Tensor shift_common(
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
return torchc;
|
||||
}
|
||||
|
@@ -133,7 +133,6 @@ public:
|
||||
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
||||
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||
std::ostringstream oss;
|
||||
|
@@ -128,3 +128,4 @@ def run_batchnorm():
|
||||
print(np.max(np.abs(db_t - db_n)))
|
||||
|
||||
run_shift()
|
||||
#run_batchnorm()
|
||||
|
@@ -99,7 +99,7 @@ public:
|
||||
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
|
||||
int64_t D = 1, T = 1;
|
||||
bool has_bias = false;
|
||||
// shift configuration
|
||||
// shift offsets
|
||||
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
|
||||
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
|
||||
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
|
||||
@@ -116,7 +116,6 @@ public:
|
||||
.first->second.get();
|
||||
else
|
||||
shift = m_config.at(key).get();
|
||||
|
||||
// shapes for c
|
||||
std::vector<int64> c_shapes;
|
||||
for(int32_t x: shift->c_shapes())
|
||||
@@ -131,49 +130,7 @@ public:
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
// get JIT
|
||||
triton::jit* jit;
|
||||
bool autotune = false;
|
||||
if(m_jit.find(key) == m_jit.end()) {
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
shift->src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return shift->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
|
||||
jit->add_module("shift", src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
// Run
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
triton::jit::launch_information info = jit->get_launch_info("shift");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
shift->enqueue(stream, {&da, &db, &dc});
|
||||
}
|
||||
|
||||
private:
|
||||
|
65
include/triton/dnn/base.h
Normal file
65
include/triton/dnn/base.h
Normal file
@@ -0,0 +1,65 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_BASE_H
|
||||
#define TDL_INCLUDE_DNN_BASE_H
|
||||
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class base {
|
||||
protected:
|
||||
// leading dimensions
|
||||
static void set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld);
|
||||
|
||||
private:
|
||||
// initialize
|
||||
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0;
|
||||
// enqueue
|
||||
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
size_t TM, size_t TN, size_t nthreads) = 0;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
base(const std::string& name);
|
||||
// number of flops
|
||||
virtual size_t get_nflops() const = 0;
|
||||
// triton-c source
|
||||
virtual void get_src(std::ostream &os) const = 0;
|
||||
// comparison for maps
|
||||
virtual bool operator<(const base& other) const = 0;
|
||||
// enqueue
|
||||
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -28,13 +28,15 @@
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include "triton/dnn/base.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class shift {
|
||||
class shift: public base {
|
||||
|
||||
public:
|
||||
enum type {
|
||||
@@ -44,8 +46,14 @@ public:
|
||||
};
|
||||
|
||||
private:
|
||||
// leading dimensions
|
||||
void set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld);
|
||||
// initialize and enqueue
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
size_t TM, size_t TN, size_t nthreads);
|
||||
|
||||
public:
|
||||
|
||||
@@ -60,26 +68,18 @@ public:
|
||||
// look-up table
|
||||
void build_deltas();
|
||||
void build_masks();
|
||||
|
||||
// accessors
|
||||
size_t a_size();
|
||||
size_t b_size();
|
||||
size_t c_size();
|
||||
std::vector<int32_t> c_shapes();
|
||||
|
||||
// device function
|
||||
void init(driver::stream *stream, driver::cu_module *module);
|
||||
void enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads);
|
||||
|
||||
// utils
|
||||
size_t get_nflops();
|
||||
|
||||
// number of flops
|
||||
size_t get_nflops() const;
|
||||
// source
|
||||
void src(std::ostream &os);
|
||||
|
||||
// cpu_ref
|
||||
void get_src(std::ostream &os) const;
|
||||
// comparison
|
||||
bool operator<(const base& other) const;
|
||||
// cpu reference
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_ref(OUT_DTYPE* O,
|
||||
const IN_DTYPE* I,
|
||||
|
69
lib/dnn/base.cpp
Normal file
69
lib/dnn/base.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#include <sstream>
|
||||
#include "triton/dnn/base.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
struct cmp_recompile{
|
||||
bool operator()(base* x, base* y) const{
|
||||
return *x < *y;
|
||||
}
|
||||
};
|
||||
|
||||
// Stores `name`; base::enqueue() later passes it to the JIT as the name of
// the module/function to compile and look up.
base::base(const std::string& name)
  : name_(name)
{
}
|
||||
|
||||
/**
 * Compile (on first use) and launch this operation.
 *
 * Compiled JIT instances are kept in a function-local static map ordered by
 * the virtual base::operator< (see cmp_recompile), so operations that compare
 * equivalent share one compiled kernel.
 *
 * @param stream stream on which initialization and the kernel are enqueued
 * @param args   device buffers forwarded unchanged to enqueue_impl()
 *
 * NOTE(review): the map is keyed by the raw `this` pointer and the comparator
 * dereferences its keys, so a cached key must outlive every later lookup --
 * confirm object lifetimes at the call sites.
 * NOTE(review): the static map is accessed without any locking here.
 */
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
  static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
  bool autotune = false;
  driver::context* ctx = stream->context();
  triton::jit* jit = nullptr;
  // single lookup; the iterator is reused on the cache-hit path
  auto cached = m_jit.find(this);
  /* the current template has not already been compiled */
  if(cached == m_jit.end()) {
    // Build into a local owner first and publish it only after compilation
    // and initialization have succeeded: if any step below throws, the cache
    // is left untouched instead of holding a JIT without a usable kernel.
    std::unique_ptr<triton::jit> fresh(new triton::jit(ctx));
    std::ostringstream oss;
    get_src(oss);
    std::string src = oss.str();
    // measures one candidate kernel; the returned figure is what autotune()
    // receives for that candidate
    auto benchmark = [&](triton::driver::kernel* kernel,
                         triton::jit::launch_information info) {
      // launch info
      unsigned TM = info.global_range_size[0];
      unsigned TN = info.global_range_size[1];
      unsigned nthreads = info.num_threads;
      init_impl(stream, (triton::driver::cu_module*)kernel->module());
      // one untimed launch before the timed runs
      enqueue_impl(stream, kernel, args, TM, TN, nthreads);
      stream->synchronize();
      double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
                                       [&](){ stream->synchronize(); }, ctx->device());
      return get_nflops() / ts * 1e-3;
    };
    // auto-tune and save result
    if(autotune) {
      triton::jit::tune_res_t best = fresh->autotune(name_.c_str(), src.c_str(), benchmark);
      fresh->add_module(name_.c_str(), src.c_str(), best.params);
    }
    else {
      fresh->add_module(name_.c_str(), src.c_str(), fresh->get_valid(name_.c_str(), src.c_str()));
    }
    triton::driver::kernel* kernel = fresh->get_function(name_.c_str());
    init_impl(stream, (triton::driver::cu_module*)kernel->module());
    // publish: `fresh` keeps ownership until the swap, so nothing leaks if
    // the map insertion itself throws
    jit = fresh.get();
    m_jit[this].swap(fresh);
  }
  /* retrieve compiled template */
  else
    jit = cached->second.get();

  /* get launch parameters */
  driver::kernel* kernel = jit->get_function(name_.c_str());
  triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
  unsigned TM = info.global_range_size[0];
  unsigned TN = info.global_range_size[1];
  unsigned nthreads = info.num_threads;

  /* launch */
  enqueue_impl(stream, kernel, args, TM, TN, nthreads);
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,5 +1,6 @@
|
||||
#include <sstream>
|
||||
#include "triton/dnn/shift.h"
|
||||
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
@@ -21,7 +22,8 @@ shift::shift(int B, int C,
|
||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
: B_(B), C_(C),
|
||||
: base("shift"),
|
||||
B_(B), C_(C),
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
F_(F),
|
||||
@@ -118,21 +120,33 @@ std::vector<int32_t> shift::c_shapes(){
|
||||
return shapes_c_;
|
||||
}
|
||||
|
||||
size_t shift::get_nflops() {
|
||||
size_t shift::get_nflops() const {
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
// Ordering used to key caches of compiled kernels: lexicographic comparison
// of the configuration fields listed below.
// The shift offsets are compared by value. Comparing the addresses returned
// by data() would make two identical configurations held in different
// objects compare as distinct, and would make the ordering depend on where
// the buffers happen to be allocated.
// NOTE(review): when `other` is not a shift this returns false, so the
// result alone cannot separate a shift from an operation of another type --
// confirm that callers never mix operation types in one container.
bool shift::operator <(const base& other) const{
  auto *y = dynamic_cast<const shift*>(&other);
  if(!y)
    return false;
  return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
                  shift_h_, shift_w_, ty_, bias_)
       < std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
                  y->shift_h_, y->shift_w_, y->ty_, y->bias_);
}
|
||||
|
||||
void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
|
||||
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
|
||||
}
|
||||
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
int32_t lda = AT_ ? K_ : M_;
|
||||
int32_t ldb = BT_ ? N_ : K_;
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
@@ -154,7 +168,7 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::src(std::ostream &os) {
|
||||
void shift::get_src(std::ostream &os) const {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
|
Reference in New Issue
Block a user