[dnn] added base template class for mutualized auto-tuning

This commit is contained in:
Philippe Tillet
2019-07-09 16:09:34 -07:00
parent 066ae338f1
commit 88675fa01a
9 changed files with 181 additions and 106 deletions

View File

@@ -14,8 +14,6 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
@@ -31,7 +29,7 @@ int main() {
shift_w[c] = rand() % S - S/2; shift_w[c] = rand() % S - S/2;
} }
// configuration // configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP); triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
// host buffers // host buffers
std::vector<float> hc(shift.c_size()); std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size()); std::vector<float> rc(shift.c_size());
@@ -55,35 +53,7 @@ int main() {
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
// benchmark shift.enqueue(stream, da, db, dc);
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
shift.init(stream, (triton::driver::cu_module*)kernel->module());
// launch infoRR
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// set argument
shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);
stream->synchronize();
// benchmark
double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);},
[&](){ stream->synchronize(); }, context->device());
return shift.get_nflops() / ts * 1e-3;
};
// shift
std::vector<unsigned> params = {
4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
};
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
// stream->read(dc, true, 0, hc); // stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data()); // shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++) // for(size_t i = 0; i < hc.size(); i++)

View File

@@ -72,18 +72,18 @@ torch::Tensor shift_common(
if(m_shift_jit.find(key) == m_shift_jit.end()){ if(m_shift_jit.find(key) == m_shift_jit.end()){
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get(); jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss; std::ostringstream oss;
configuration->src(oss); configuration->get_src(oss);
std::string src = oss.str(); std::string src = oss.str();
// benchmark a given shiftolution kernel // benchmark a given shiftolution kernel
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) { triton::jit::launch_information info) {
configuration->init(stream, (triton::driver::cu_module*)kernel->module()); configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0]; unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1]; unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads; unsigned nthreads = info.num_threads;
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
stream->synchronize(); stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); }, double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device()); [&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3; return configuration->get_nflops() / ts * 1e-3;
}; };
@@ -96,7 +96,7 @@ torch::Tensor shift_common(
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str())); jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
} }
triton::driver::kernel* kernel = jit->get_function("shift"); triton::driver::kernel* kernel = jit->get_function("shift");
configuration->init(stream, (triton::driver::cu_module*)kernel->module()); configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
} }
else else
jit = m_shift_jit.at(key).get(); jit = m_shift_jit.at(key).get();
@@ -109,6 +109,6 @@ torch::Tensor shift_common(
unsigned TN = info.global_range_size[1]; unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads; unsigned nthreads = info.num_threads;
// enqueue // enqueue
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
return torchc; return torchc;
} }

View File

@@ -133,7 +133,6 @@ public:
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false); triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false); triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false); triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config // create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss; std::ostringstream oss;

View File

@@ -128,3 +128,4 @@ def run_batchnorm():
print(np.max(np.abs(db_t - db_n))) print(np.max(np.abs(db_t - db_n)))
run_shift() run_shift()
#run_batchnorm()

View File

@@ -99,7 +99,7 @@ public:
FillShapes(context, C, H, W, B, F, tf_a, tf_b); FillShapes(context, C, H, W, B, F, tf_a, tf_b);
int64_t D = 1, T = 1; int64_t D = 1, T = 1;
bool has_bias = false; bool has_bias = false;
// shift configuration // shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data(); int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data(); int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C); std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
@@ -116,7 +116,6 @@ public:
.first->second.get(); .first->second.get();
else else
shift = m_config.at(key).get(); shift = m_config.at(key).get();
// shapes for c // shapes for c
std::vector<int64> c_shapes; std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes()) for(int32_t x: shift->c_shapes())
@@ -131,49 +130,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false); triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
// get JIT shift->enqueue(stream, {&da, &db, &dc});
triton::jit* jit;
bool autotune = false;
if(m_jit.find(key) == m_jit.end()) {
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
shift->src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
shift->init(stream, (triton::driver::cu_module*)kernel->module());
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return shift->get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
jit->add_module("shift", src.c_str(), best.params);
}
else {
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function("shift");
shift->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Run
triton::driver::kernel* kernel = jit->get_function("shift");
triton::jit::launch_information info = jit->get_launch_info("shift");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// enqueue
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
} }
private: private:

65
include/triton/dnn/base.h Normal file
View File

@@ -0,0 +1,65 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TDL_INCLUDE_DNN_BASE_H
#define TDL_INCLUDE_DNN_BASE_H
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton{
namespace dnn{
class base {
protected:
// leading dimensions
static void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);
private:
// initialize
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0;
// enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads) = 0;
public:
// constructor
base(const std::string& name);
// number of flops
virtual size_t get_nflops() const = 0;
// triton-c source
virtual void get_src(std::ostream &os) const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
private:
std::string name_;
};
}
}
#endif

View File

@@ -28,13 +28,15 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <cmath> #include <cmath>
#include "triton/dnn/base.h"
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
#include "triton/runtime/jit.h"
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
class shift { class shift: public base {
public: public:
enum type { enum type {
@@ -44,8 +46,14 @@ public:
}; };
private: private:
// leading dimensions
void set_ld(const std::vector<int32_t>& shapes, void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld); std::vector<int32_t>& ld);
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads);
public: public:
@@ -60,26 +68,18 @@ public:
// look-up table // look-up table
void build_deltas(); void build_deltas();
void build_masks(); void build_masks();
// accessors // accessors
size_t a_size(); size_t a_size();
size_t b_size(); size_t b_size();
size_t c_size(); size_t c_size();
std::vector<int32_t> c_shapes(); std::vector<int32_t> c_shapes();
// number of flops
// device function size_t get_nflops() const;
void init(driver::stream *stream, driver::cu_module *module);
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
size_t TM, size_t TN, size_t nthreads);
// utils
size_t get_nflops();
// source // source
void src(std::ostream &os); void get_src(std::ostream &os) const;
// comparison
// cpu_ref bool operator<(const base& other) const;
// cpu reference
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* O, void cpu_ref(OUT_DTYPE* O,
const IN_DTYPE* I, const IN_DTYPE* I,

69
lib/dnn/base.cpp Normal file
View File

@@ -0,0 +1,69 @@
#include <sstream>
#include "triton/dnn/base.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
namespace triton{
namespace dnn{
struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};
base::base(const std::string& name)
: name_(name) { }
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
bool autotune = false;
driver::context* ctx = stream->context();
triton::jit* jit;
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
get_src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
init_impl(stream, (triton::driver::cu_module*)kernel->module());
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
init_impl(stream, (triton::driver::cu_module*)kernel->module());
}
/* retrieved compiled template */
else
jit = m_jit.at(this).get();
/* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str());
triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
/* launch */
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
}
}
}

View File

@@ -1,5 +1,6 @@
#include <sstream>
#include "triton/dnn/shift.h" #include "triton/dnn/shift.h"
#include "triton/tools/bench.hpp"
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
@@ -21,7 +22,8 @@ shift::shift(int B, int C,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w, const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
type ty, bool bias) type ty, bool bias)
: B_(B), C_(C), : base("shift"),
B_(B), C_(C),
AD_(D), AH_(H), AW_(W), AD_(D), AH_(H), AW_(W),
BD_(T), BH_(R), BW_(S), BD_(T), BH_(R), BW_(S),
F_(F), F_(F),
@@ -118,21 +120,33 @@ std::vector<int32_t> shift::c_shapes(){
return shapes_c_; return shapes_c_;
} }
size_t shift::get_nflops() { size_t shift::get_nflops() const {
return 2.*M_*N_*K_; return 2.*M_*N_*K_;
} }
bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other);
if(!y)
return false;
const int32_t *x_shift_h = shift_h_.data(), *x_shift_w = shift_w_.data();
const int32_t *y_shift_h = y->shift_h_.data(), *y_shift_w = y->shift_w_.data();
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
x_shift_h, x_shift_w, ty_, bias_)
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
y_shift_h, y_shift_w, y->ty_, y->bias_);
}
void shift::init(driver::stream *stream, driver::cu_module *module) { void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta"); triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data()); stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
} }
void shift::enqueue(driver::stream *stream, driver::kernel *kernel, void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c, std::vector<driver::buffer *> args,
size_t TM, size_t TN, size_t nthreads) { size_t TM, size_t TN, size_t nthreads) {
int32_t lda = AT_ ? K_ : M_; int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_; int32_t ldb = BT_ ? N_ : K_;
driver::buffer *a = args[0], *b = args[1], *c = args[2];
kernel->setArg(0, a); kernel->setArg(0, a);
kernel->setArg(1, b); kernel->setArg(1, b);
kernel->setArg(2, c); kernel->setArg(2, c);
@@ -154,7 +168,7 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
void shift::src(std::ostream &os) { void shift::get_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK"; std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN"; std::string BS0 = "TK", BS1 = "TN";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";