[dnn] added base template class for mutualized auto-tuning
This commit is contained in:
@@ -14,8 +14,6 @@ int main() {
|
|||||||
|
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
// initialize just-in-time compiler
|
|
||||||
triton::jit jit(context);
|
|
||||||
|
|
||||||
// initialization
|
// initialization
|
||||||
int32_t R = 3, S = 3;
|
int32_t R = 3, S = 3;
|
||||||
@@ -31,7 +29,7 @@ int main() {
|
|||||||
shift_w[c] = rand() % S - S/2;
|
shift_w[c] = rand() % S - S/2;
|
||||||
}
|
}
|
||||||
// configuration
|
// configuration
|
||||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||||
// host buffers
|
// host buffers
|
||||||
std::vector<float> hc(shift.c_size());
|
std::vector<float> hc(shift.c_size());
|
||||||
std::vector<float> rc(shift.c_size());
|
std::vector<float> rc(shift.c_size());
|
||||||
@@ -55,35 +53,7 @@ int main() {
|
|||||||
stream->write(db, true, 0, hb);
|
stream->write(db, true, 0, hb);
|
||||||
stream->write(dc, true, 0, hc);
|
stream->write(dc, true, 0, hc);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
// benchmark
|
shift.enqueue(stream, da, db, dc);
|
||||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
|
||||||
triton::jit::launch_information info) {
|
|
||||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
|
||||||
// launch infoRR
|
|
||||||
unsigned TM = info.global_range_size[0];
|
|
||||||
unsigned TN = info.global_range_size[1];
|
|
||||||
unsigned nthreads = info.num_threads;
|
|
||||||
// set argument
|
|
||||||
shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);
|
|
||||||
stream->synchronize();
|
|
||||||
// benchmark
|
|
||||||
double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);},
|
|
||||||
[&](){ stream->synchronize(); }, context->device());
|
|
||||||
return shift.get_nflops() / ts * 1e-3;
|
|
||||||
};
|
|
||||||
|
|
||||||
// shift
|
|
||||||
std::vector<unsigned> params = {
|
|
||||||
4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
|
|
||||||
};
|
|
||||||
std::ostringstream oss;
|
|
||||||
shift.src(oss);
|
|
||||||
std::string src = oss.str();
|
|
||||||
// jit.autotune("shift", src.c_str(), benchmark);
|
|
||||||
jit.add_module("shift", src.c_str(), params);
|
|
||||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
|
||||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
|
||||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
|
||||||
// stream->read(dc, true, 0, hc);
|
// stream->read(dc, true, 0, hc);
|
||||||
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||||
// for(size_t i = 0; i < hc.size(); i++)
|
// for(size_t i = 0; i < hc.size(); i++)
|
||||||
|
@@ -72,18 +72,18 @@ torch::Tensor shift_common(
|
|||||||
if(m_shift_jit.find(key) == m_shift_jit.end()){
|
if(m_shift_jit.find(key) == m_shift_jit.end()){
|
||||||
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
configuration->src(oss);
|
configuration->get_src(oss);
|
||||||
std::string src = oss.str();
|
std::string src = oss.str();
|
||||||
// benchmark a given shiftolution kernel
|
// benchmark a given shiftolution kernel
|
||||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||||
triton::jit::launch_information info) {
|
triton::jit::launch_information info) {
|
||||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||||
unsigned TM = info.global_range_size[0];
|
unsigned TM = info.global_range_size[0];
|
||||||
unsigned TN = info.global_range_size[1];
|
unsigned TN = info.global_range_size[1];
|
||||||
unsigned nthreads = info.num_threads;
|
unsigned nthreads = info.num_threads;
|
||||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||||
[&](){ stream->synchronize(); }, stream->context()->device());
|
[&](){ stream->synchronize(); }, stream->context()->device());
|
||||||
return configuration->get_nflops() / ts * 1e-3;
|
return configuration->get_nflops() / ts * 1e-3;
|
||||||
};
|
};
|
||||||
@@ -96,7 +96,7 @@ torch::Tensor shift_common(
|
|||||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||||
}
|
}
|
||||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
jit = m_shift_jit.at(key).get();
|
jit = m_shift_jit.at(key).get();
|
||||||
@@ -109,6 +109,6 @@ torch::Tensor shift_common(
|
|||||||
unsigned TN = info.global_range_size[1];
|
unsigned TN = info.global_range_size[1];
|
||||||
unsigned nthreads = info.num_threads;
|
unsigned nthreads = info.num_threads;
|
||||||
// enqueue
|
// enqueue
|
||||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||||
return torchc;
|
return torchc;
|
||||||
}
|
}
|
||||||
|
@@ -133,7 +133,6 @@ public:
|
|||||||
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
||||||
|
|
||||||
// create config
|
// create config
|
||||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@@ -128,3 +128,4 @@ def run_batchnorm():
|
|||||||
print(np.max(np.abs(db_t - db_n)))
|
print(np.max(np.abs(db_t - db_n)))
|
||||||
|
|
||||||
run_shift()
|
run_shift()
|
||||||
|
#run_batchnorm()
|
||||||
|
@@ -99,7 +99,7 @@ public:
|
|||||||
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
|
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
|
||||||
int64_t D = 1, T = 1;
|
int64_t D = 1, T = 1;
|
||||||
bool has_bias = false;
|
bool has_bias = false;
|
||||||
// shift configuration
|
// shift offsets
|
||||||
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
|
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
|
||||||
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
|
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
|
||||||
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
|
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
|
||||||
@@ -116,7 +116,6 @@ public:
|
|||||||
.first->second.get();
|
.first->second.get();
|
||||||
else
|
else
|
||||||
shift = m_config.at(key).get();
|
shift = m_config.at(key).get();
|
||||||
|
|
||||||
// shapes for c
|
// shapes for c
|
||||||
std::vector<int64> c_shapes;
|
std::vector<int64> c_shapes;
|
||||||
for(int32_t x: shift->c_shapes())
|
for(int32_t x: shift->c_shapes())
|
||||||
@@ -131,49 +130,7 @@ public:
|
|||||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||||
// get JIT
|
shift->enqueue(stream, {&da, &db, &dc});
|
||||||
triton::jit* jit;
|
|
||||||
bool autotune = false;
|
|
||||||
if(m_jit.find(key) == m_jit.end()) {
|
|
||||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
|
||||||
std::ostringstream oss;
|
|
||||||
shift->src(oss);
|
|
||||||
std::string src = oss.str();
|
|
||||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
|
||||||
triton::jit::launch_information info) {
|
|
||||||
// launch info
|
|
||||||
unsigned TM = info.global_range_size[0];
|
|
||||||
unsigned TN = info.global_range_size[1];
|
|
||||||
unsigned nthreads = info.num_threads;
|
|
||||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
|
||||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
|
||||||
stream->synchronize();
|
|
||||||
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
|
|
||||||
[&](){ stream->synchronize(); }, ctx->device());
|
|
||||||
return shift->get_nflops() / ts * 1e-3;
|
|
||||||
};
|
|
||||||
// auto-tune and save result
|
|
||||||
if(autotune) {
|
|
||||||
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
|
|
||||||
jit->add_module("shift", src.c_str(), best.params);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
|
||||||
}
|
|
||||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
|
||||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
jit = m_jit.at(key).get();
|
|
||||||
// Run
|
|
||||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
|
||||||
triton::jit::launch_information info = jit->get_launch_info("shift");
|
|
||||||
// launch info
|
|
||||||
unsigned TM = info.global_range_size[0];
|
|
||||||
unsigned TN = info.global_range_size[1];
|
|
||||||
unsigned nthreads = info.num_threads;
|
|
||||||
// enqueue
|
|
||||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
65
include/triton/dnn/base.h
Normal file
65
include/triton/dnn/base.h
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2015-2017 Philippe Tillet
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
* a copy of this software and associated documentation files
|
||||||
|
* (the "Software"), to deal in the Software without restriction,
|
||||||
|
* including without limitation the rights to use, copy, modify, merge,
|
||||||
|
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||||
|
* and to permit persons to whom the Software is furnished to do so,
|
||||||
|
* subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be
|
||||||
|
* included in all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||||
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||||
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||||
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef TDL_INCLUDE_DNN_BASE_H
|
||||||
|
#define TDL_INCLUDE_DNN_BASE_H
|
||||||
|
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "triton/driver/kernel.h"
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
namespace dnn{
|
||||||
|
|
||||||
|
class base {
|
||||||
|
protected:
|
||||||
|
// leading dimensions
|
||||||
|
static void set_ld(const std::vector<int32_t>& shapes,
|
||||||
|
std::vector<int32_t>& ld);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// initialize
|
||||||
|
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0;
|
||||||
|
// enqueue
|
||||||
|
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||||
|
std::vector<driver::buffer*> args,
|
||||||
|
size_t TM, size_t TN, size_t nthreads) = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// constructor
|
||||||
|
base(const std::string& name);
|
||||||
|
// number of flops
|
||||||
|
virtual size_t get_nflops() const = 0;
|
||||||
|
// triton-c source
|
||||||
|
virtual void get_src(std::ostream &os) const = 0;
|
||||||
|
// comparison for maps
|
||||||
|
virtual bool operator<(const base& other) const = 0;
|
||||||
|
// enqueue
|
||||||
|
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -28,13 +28,15 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include "triton/dnn/base.h"
|
||||||
#include "triton/driver/stream.h"
|
#include "triton/driver/stream.h"
|
||||||
#include "triton/driver/kernel.h"
|
#include "triton/driver/kernel.h"
|
||||||
|
#include "triton/runtime/jit.h"
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace dnn{
|
namespace dnn{
|
||||||
|
|
||||||
class shift {
|
class shift: public base {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
enum type {
|
enum type {
|
||||||
@@ -44,8 +46,14 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// leading dimensions
|
||||||
void set_ld(const std::vector<int32_t>& shapes,
|
void set_ld(const std::vector<int32_t>& shapes,
|
||||||
std::vector<int32_t>& ld);
|
std::vector<int32_t>& ld);
|
||||||
|
// initialize and enqueue
|
||||||
|
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||||
|
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||||
|
std::vector<driver::buffer*> args,
|
||||||
|
size_t TM, size_t TN, size_t nthreads);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@@ -60,26 +68,18 @@ public:
|
|||||||
// look-up table
|
// look-up table
|
||||||
void build_deltas();
|
void build_deltas();
|
||||||
void build_masks();
|
void build_masks();
|
||||||
|
|
||||||
// accessors
|
// accessors
|
||||||
size_t a_size();
|
size_t a_size();
|
||||||
size_t b_size();
|
size_t b_size();
|
||||||
size_t c_size();
|
size_t c_size();
|
||||||
std::vector<int32_t> c_shapes();
|
std::vector<int32_t> c_shapes();
|
||||||
|
// number of flops
|
||||||
// device function
|
size_t get_nflops() const;
|
||||||
void init(driver::stream *stream, driver::cu_module *module);
|
|
||||||
void enqueue(driver::stream *stream, driver::kernel *kernel,
|
|
||||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
|
||||||
size_t TM, size_t TN, size_t nthreads);
|
|
||||||
|
|
||||||
// utils
|
|
||||||
size_t get_nflops();
|
|
||||||
|
|
||||||
// source
|
// source
|
||||||
void src(std::ostream &os);
|
void get_src(std::ostream &os) const;
|
||||||
|
// comparison
|
||||||
// cpu_ref
|
bool operator<(const base& other) const;
|
||||||
|
// cpu reference
|
||||||
template<class IN_DTYPE, class OUT_DTYPE>
|
template<class IN_DTYPE, class OUT_DTYPE>
|
||||||
void cpu_ref(OUT_DTYPE* O,
|
void cpu_ref(OUT_DTYPE* O,
|
||||||
const IN_DTYPE* I,
|
const IN_DTYPE* I,
|
||||||
|
69
lib/dnn/base.cpp
Normal file
69
lib/dnn/base.cpp
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#include <sstream>
|
||||||
|
#include "triton/dnn/base.h"
|
||||||
|
#include "triton/runtime/jit.h"
|
||||||
|
#include "triton/tools/bench.hpp"
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
namespace dnn{
|
||||||
|
|
||||||
|
struct cmp_recompile{
|
||||||
|
bool operator()(base* x, base* y) const{
|
||||||
|
return *x < *y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
base::base(const std::string& name)
|
||||||
|
: name_(name) { }
|
||||||
|
|
||||||
|
/*
 * Compile (on first use) and launch the kernel generated by this template.
 *
 * A function-local cache maps templates -- ordered through cmp_recompile,
 * i.e. through operator< -- to the triton::jit that holds their compiled
 * module. The source is generated and compiled only when no equivalent
 * template is cached yet.
 *
 * stream: stream used for initialization, benchmarking and the final launch
 * args:   buffers forwarded unchanged to enqueue_impl
 *
 * NOTE(review): the cache key is the raw `this` pointer and the comparator
 * dereferences keys, so a template destroyed while still cached leaves a
 * dangling key behind -- confirm templates outlive the cache or use an
 * owning key.
 * NOTE(review): nothing in this function locks the static cache.
 */
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
  static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
  bool autotune = false;
  driver::context* ctx = stream->context();

  /* single lookup; the iterator is reused for the hit and the miss path */
  auto it = m_jit.find(this);

  /* the current template has not already been compiled */
  if(it == m_jit.end()) {
    /* owned locally until fully built: a failure below neither leaks the
       jit nor leaves a half-initialized entry in the cache */
    std::unique_ptr<triton::jit> fresh(new triton::jit(ctx));
    std::ostringstream oss;
    get_src(oss);
    std::string src = oss.str();
    /* score of one candidate kernel: flop count over measured time, scaled
       by 1e-3 (presumably TFLOPS if bench reports nanoseconds -- confirm) */
    auto benchmark = [&](triton::driver::kernel* kernel,
                         triton::jit::launch_information info) {
      // launch info
      unsigned TM = info.global_range_size[0];
      unsigned TN = info.global_range_size[1];
      unsigned nthreads = info.num_threads;
      init_impl(stream, (triton::driver::cu_module*)kernel->module());
      /* one launch + sync outside of the timed region */
      enqueue_impl(stream, kernel, args, TM, TN, nthreads);
      stream->synchronize();
      double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
                                       [&](){ stream->synchronize(); }, ctx->device());
      return get_nflops() / ts * 1e-3;
    };
    // auto-tune and save result
    if(autotune) {
      triton::jit::tune_res_t best = fresh->autotune(name_.c_str(), src.c_str(), benchmark);
      fresh->add_module(name_.c_str(), src.c_str(), best.params);
    }
    else {
      fresh->add_module(name_.c_str(), src.c_str(), fresh->get_valid(name_.c_str(), src.c_str()));
    }
    triton::driver::kernel* compiled = fresh->get_function(name_.c_str());
    init_impl(stream, (triton::driver::cu_module*)compiled->module());
    /* publish only now that the module is compiled and initialized; the
       empty slot is created first so ownership moves without a raw pointer */
    it = m_jit.emplace(this, nullptr).first;
    it->second.swap(fresh);
  }

  /* retrieved compiled template */
  triton::jit* jit = it->second.get();

  /* get launch parameters */
  driver::kernel* kernel = jit->get_function(name_.c_str());
  triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
  unsigned TM = info.global_range_size[0];
  unsigned TN = info.global_range_size[1];
  unsigned nthreads = info.num_threads;

  /* launch */
  enqueue_impl(stream, kernel, args, TM, TN, nthreads);
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
@@ -1,5 +1,6 @@
|
|||||||
|
#include <sstream>
|
||||||
#include "triton/dnn/shift.h"
|
#include "triton/dnn/shift.h"
|
||||||
|
#include "triton/tools/bench.hpp"
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace dnn{
|
namespace dnn{
|
||||||
@@ -21,7 +22,8 @@ shift::shift(int B, int C,
|
|||||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||||
std::string a_ty, std::string b_ty,
|
std::string a_ty, std::string b_ty,
|
||||||
type ty, bool bias)
|
type ty, bool bias)
|
||||||
: B_(B), C_(C),
|
: base("shift"),
|
||||||
|
B_(B), C_(C),
|
||||||
AD_(D), AH_(H), AW_(W),
|
AD_(D), AH_(H), AW_(W),
|
||||||
BD_(T), BH_(R), BW_(S),
|
BD_(T), BH_(R), BW_(S),
|
||||||
F_(F),
|
F_(F),
|
||||||
@@ -118,21 +120,33 @@ std::vector<int32_t> shift::c_shapes(){
|
|||||||
return shapes_c_;
|
return shapes_c_;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t shift::get_nflops() {
|
size_t shift::get_nflops() const {
|
||||||
return 2.*M_*N_*K_;
|
return 2.*M_*N_*K_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Ordering used to decide whether two shift templates are interchangeable
 * (neither one less than the other).
 *
 * The comparison is lexicographic over the problem sizes, the per-channel
 * shift tables, the operation type and the bias flag. The shift tables are
 * compared by content: comparing their data() pointers, as was done before,
 * made two identically configured objects always differ and relied on the
 * unspecified ordering of unrelated pointers.
 *
 * NOTE(review): a non-shift `other` yields false, so from this side such a
 * template is never "greater" than a shift one -- confirm that containers
 * mixing template kinds cannot treat them as equivalent.
 */
bool shift::operator <(const base& other) const{
  auto *y = dynamic_cast<const shift*>(&other);
  if(!y)
    return false;
  return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
                  shift_h_, shift_w_, ty_, bias_)
       < std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
                  y->shift_h_, y->shift_w_, y->ty_, y->bias_);
}
|
||||||
|
|
||||||
void shift::init(driver::stream *stream, driver::cu_module *module) {
|
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||||
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
|
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
|
||||||
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
|
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
std::vector<driver::buffer *> args,
|
||||||
size_t TM, size_t TN, size_t nthreads) {
|
size_t TM, size_t TN, size_t nthreads) {
|
||||||
int32_t lda = AT_ ? K_ : M_;
|
int32_t lda = AT_ ? K_ : M_;
|
||||||
int32_t ldb = BT_ ? N_ : K_;
|
int32_t ldb = BT_ ? N_ : K_;
|
||||||
|
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||||
kernel->setArg(0, a);
|
kernel->setArg(0, a);
|
||||||
kernel->setArg(1, b);
|
kernel->setArg(1, b);
|
||||||
kernel->setArg(2, c);
|
kernel->setArg(2, c);
|
||||||
@@ -154,7 +168,7 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
|||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
void shift::src(std::ostream &os) {
|
void shift::get_src(std::ostream &os) const {
|
||||||
std::string AS0 = "TM", AS1 = "TK";
|
std::string AS0 = "TM", AS1 = "TK";
|
||||||
std::string BS0 = "TK", BS1 = "TN";
|
std::string BS0 = "TK", BS1 = "TN";
|
||||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||||
|
Reference in New Issue
Block a user