[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -19,15 +19,6 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t*, int32_t*,
triton::dnn::shift::type, bool> shift_key_t;
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel {
public:
@@ -102,23 +93,15 @@ public:
// shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration
triton::dnn::shift* shift;
if(m_config.find(key) == m_config.end())
shift = m_config.emplace(key, new triton::dnn::shift(
B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h, shift_w,
"fp32", "fp32", OP, has_bias))
.first->second.get();
else
shift = m_config.at(key).get();
triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h_data, shift_w_data,
"fp32", "fp32", OP, has_bias);
// shapes for c
std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes())
for(int32_t x: shift.c_shapes())
c_shapes.push_back(x);
TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
shift->enqueue(stream, {&da, &db, &dc});
shift.enqueue(stream, {&da, &db, &dc});
}
private: