// Files
// triton/include/triton/dnn/shift.h
// 2019-07-21 18:11:54 -07:00
//
// 185 lines
// 4.5 KiB
// C++

/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TDL_INCLUDE_DNN_SHIFT_H
#define TDL_INCLUDE_DNN_SHIFT_H
#include <string>
#include <vector>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <cstdint>
#include "triton/dnn/base.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton{
namespace dnn{
// Descriptor for a "shift" DNN operation built on top of dnn::base.
// This header only declares the interface; apart from the CPU reference
// implementation (cpu_ref) every member function is defined elsewhere.
class shift: public base {
public:
  // Kind of pass this object describes.
  // NOTE(review): presumably forward / data-gradient / weight-gradient --
  // the semantics are implemented outside this header.
  enum op_t {
    FPROP,
    BPROP,
    WGRAD
  };

  // Memory layout selector for the tensors.
  enum layout_t {
    NCHW,
    CHWN
  };

private:
  // initialize and enqueue
  void init_impl(driver::stream *stream, driver::cu_module *module);
  void deinit_impl();
  void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                    std::vector<driver::buffer*> args,
                    triton::runtime::launch_information info);
  std::vector<unsigned> default_params() const;

public:
  // Constructs the operation descriptor.
  // Defaults: both data types are "fp32", the op is FPROP, bias is disabled
  // and the layout is CHWN.
  // NOTE(review): B/NC/D/H/W look like batch, channels and input extents,
  // T/R/S like filter extents and NF like the number of filters -- confirm
  // against the constructor definition.
  // shift_h / shift_w are taken as raw pointers; cpu_ref() indexes them by
  // channel. Their ownership and lifetime are not visible from this header.
  shift(int B, int NC,
        int D, int H, int W,
        int T, int R, int S, int NF,
        int stride_h, int stride_w,
        const int32_t* shift_h, const int32_t* shift_w,
        std::string a_ty = "fp32", std::string b_ty = "fp32",
        op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);

  // look-up table
  void build_delta_a();
  void build_masks();

  // accessors
  size_t c_size();
  std::vector<int32_t> c_shapes();

  // number of flops
  size_t num_flops() const;

  // source
  void triton_c_src(std::ostream &os) const;

  // comparison
  bool operator<(const base& other) const;

  // clone
  base* clone() const;

  // CPU reference implementation.
  //
  // For every output position (p, q), batch entry bs and filter k this
  // accumulates, over all channels c, I[...] * F[k + c*F_] and stores the
  // result in O. When (p, q) lies in the interior of the image -- at least
  // BH_/2 rows and BW_/2 columns away from each border -- the input pixel is
  // read at (p + shift_h_[c], q + shift_w_[c]); on the border it is read
  // unshifted at (p, q).
  //
  // Indexing used (fastest-varying dimension first):
  //   I: [bs, w, h, c] with extents B_, AW_, AH_, C_
  //   F: [k, c]        with extents F_, C_
  //   O: [bs, q, p, k] with extents B_, AW_, AH_, F_
  //
  // The routine does not consult op_, layout_, bias_ or the stride members.
  // NOTE(review): in-bounds reads rely on |shift_h_[c]| <= BH_/2 and
  // |shift_w_[c]| <= BW_/2 -- confirm that callers guarantee this.
  //
  // All offsets are computed in 64-bit arithmetic so that products such as
  // c*B_*AH_*AW_ cannot overflow a 32-bit int on large tensors.
  template<class IN_DTYPE, class OUT_DTYPE>
  void cpu_ref(OUT_DTYPE* O,
               const IN_DTYPE* I,
               const IN_DTYPE* F)
  {
    // Strides widened once; every index expression below is then evaluated
    // in int64_t thanks to the usual arithmetic conversions.
    const int64_t ld_w = B_;
    const int64_t ld_h = ld_w * AW_;
    const int64_t ld_c = ld_h * AH_;
    const int64_t ld_f = F_;

    OUT_DTYPE acc;
    for(int32_t p = 0; p < AH_; ++p)
    for(int32_t q = 0; q < AW_; ++q)
    {
      // The interior test only depends on (p, q), so evaluate it once here
      // instead of once per channel.
      const bool interior = p >= BH_/2 && p < AH_ - BH_/2
                         && q >= BW_/2 && q < AW_ - BW_/2;
      for(int32_t bs = 0; bs < B_; ++bs)
      for(int32_t k = 0; k < F_; ++k)
      {
        acc = 0;
        for(int32_t c = 0; c < C_; ++c){
          int32_t h = p;
          int32_t w = q;
          if(interior){
            h += shift_h_[c];
            w += shift_w_[c];
          }
          IN_DTYPE a = I[bs + w*ld_w + h*ld_h + c*ld_c];
          IN_DTYPE b = F[k + c*ld_f];
          acc = std::fma(a, b, acc);
        }
        O[bs + q*ld_w + p*ld_h + k*ld_c] = acc;
      }
    }
  }

private:
  // NOTE(review): meaning of MAX_C_ and TK_ is not visible in this header.
  int32_t MAX_C_;
  int32_t TK_;
  // image size (B_ = batch extent, C_ = channel extent as used by cpu_ref)
  int32_t B_;
  int32_t C_;
  int32_t AD_;
  int32_t AH_;
  int32_t AW_;
  // filter size (F_ = number of filters as used by cpu_ref)
  int32_t BD_;
  int32_t BH_;
  int32_t BW_;
  int32_t F_;
  // activation size
  int32_t CD_;
  int32_t CH_;
  int32_t CW_;
  // interior image size
  int32_t IAD_;
  int32_t IAH_;
  int32_t IAW_;
  // interior activation size
  int32_t ICD_;
  int32_t ICH_;
  int32_t ICW_;
  // equivalent matmul
  int32_t M_;
  int32_t N_;
  int32_t K_;
  // shapes
  std::vector<int32_t> shapes_c_;
  // strides
  int32_t stride_d_;
  int32_t stride_h_;
  int32_t stride_w_;
  // memory strides
  int32_t lda_n_, lda_c_, lda_h_, lda_w_;
  int32_t ldb_n_, ldb_c_, ldb_h_, ldb_w_;
  int32_t ldc_n_, ldc_f_, ldc_h_, ldc_w_;
  // shift values (per-channel offsets; indexed by channel in cpu_ref)
  const int32_t* shift_h_;
  const int32_t* shift_w_;
  bool shift_edge_h_;
  bool shift_edge_w_;
  // look-up tables (h_* are std::vector copies, d_* are driver buffers)
  std::vector<int32_t> h_delta_a;
  std::vector<int32_t> h_delta_b;
  driver::buffer* d_delta_a;
  driver::buffer* d_delta_b;
  // data types
  std::string a_ty_;
  std::string b_ty_;
  std::string c_ty_;
  // operation type and bias flag
  op_t op_;
  bool bias_;
  // transpose
  bool AT_;
  bool BT_;
  // layout
  layout_t layout_;
  // locks
  size_t max_locks_;
  driver::buffer *locks_;
};
}
}
#endif