// triton/include/triton/dnn/conv.h
// Snapshot of 2019-05-22 17:49:40 -04:00 (282 lines, 8.3 KiB, C++).

#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton{
namespace dnn{

/// Convolution lowered to an equivalent matrix multiplication
/// (M_ x N_ x K_) and compiled as a Triton-C kernel.
///
/// This header holds the problem geometry, host-side look-up tables
/// (`h_*`) with their driver buffers (`d_*`), and `src()`, which emits the
/// kernel source specialised for this configuration. Every other member
/// function is defined out of line.
class conv {
public:
  /// Pass computed by the kernel. Only FPROP is tested in this header (it
  /// gates the bias epilogue in `src()`); operand roles for BPROP/WGRAD are
  /// presumably arranged by the out-of-line constructor — TODO confirm.
  enum type {
    FPROP,
    BPROP,
    WGRAD
  };

  /// Builds a convolution problem (defined out of line).
  /// NOTE(review): meanings inferred from names — B batch size, NC input
  /// channels, D/H/W input extents, T/R/S filter extents, NF number of
  /// filters, then per-axis stride, padding and upsampling. Confirm against
  /// the definition.
  /// @param ty   pass to generate (default FPROP).
  /// @param bias presumably stored in `bias_`, which makes the FPROP kernel
  ///             add a bias indexed by output column.
  conv(int B, int NC,
       int D, int H, int W,
       int T, int R, int S, int NF,
       int stride_d, int stride_h, int stride_w,
       int pad_d, int pad_h, int pad_w,
       int upsample_d, int upsample_h, int upsample_w,
       type ty = FPROP, bool bias = false);

  // accessors
  /// Sizes of the A, B and C operands (defined out of line).
  size_t a_size();
  size_t b_size();
  size_t c_size();
  /// Shape of the C operand (defined out of line).
  std::vector<int32_t> c_shapes();

  // initialize
  /// Presumably fill `h_a_deltas_`/`h_b_deltas_` and `h_masks_` — defined
  /// out of line; TODO confirm.
  void build_deltas();
  void build_masks();
  /// Presumably makes the look-up tables available to the kernel (constant
  /// memory of `module` or the `d_*` buffers) — TODO confirm.
  void init(driver::stream *stream, driver::cu_module *module);
  /// Launch grid for output tiles of TM x TN (defined out of line).
  std::array<size_t, 3> get_grid(size_t TM, size_t TN);
  /// Binds operand buffers (and presumably the scalar arguments of the
  /// generated `conv` signature) to `kernel`.
  void set_arg(driver::kernel *kernel,
               driver::buffer *a, driver::buffer *b, driver::buffer *c,
               driver::buffer *bias);

  // utilities
  size_t get_nflops();
  std::vector<unsigned> default_params();

  // source
  /// Generates the Triton-C source of the convolution kernel.
  ///
  /// The kernel tiles the equivalent matmul with tunable TM/TN/TK, advances
  /// A through the `delta` look-up table, predicates A loads with the bit
  /// masks in `masks`, and advances B through `b_delta` when `b_lut_` is
  /// set (a fixed TK stride otherwise). Each table is declared as an
  /// `alloc_const` constant array or passed as a trailing kernel argument,
  /// depending on the `is_*_cst` flags.
  ///
  /// @return the complete kernel source.
  std::string src(){
    // B tile layout and the matching broadcast/stride suffixes: transposed
    // B is loaded as [TN,TK] and fed to dot() as trans(b).
    std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
    std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
    std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
    std::string ldb0 = b_trans_ ? "*ldb_s" : "";
    std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
    std::string useb = b_trans_ ? "trans(b)" : "b";
    // The non-transposed path reads the filter spatially flipped.
    std::string flipr = b_trans_ ? "" : "BH - 1 -";
    std::string flips = b_trans_ ? "" : "BW - 1 -";
    // Axes that the reduction index rka is split into, outermost first.
    // redax[2] / redax[1] are the extents of the two inner axes; redax[0]
    // is never emitted.
    std::string ax = b_trans_ ? "crs" : "rsc";
    std::vector<std::string> redax;
    if(b_trans_)
      redax = {"C", "BH", "BW"};
    else
      redax = {"BH", "BW", "N"};
    // Per-step B increment. With a LUT, db is a [TK] vector and must be
    // broadcast along the TK axis of the B tile, exactly like rb1 (bcb1).
    // Without a LUT, use a fixed TK stride.
    std::string inc_pb = b_lut_ ? "db" + bcb1 : "TK" + ldb0;
    // Address-space qualifier for each LUT pointer.
    std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
    std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
    std::string masks_mem = is_mask_cst_? "__constant__" : "";
    std::string res =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};
)";
    // Constant-memory tables, sized from the host-side vectors.
    if(is_a_deltas_cst)
      res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
    if(b_lut_ && is_b_deltas_cst_)
      res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
    if(is_mask_cst_)
      res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
    // Kernel signature; tables that are not constant become trailing args.
    res += R"(
void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
fp32 *bias,
int32 M, int32 N, int32 K,
int32 AH, int32 AW,
int32 BH, int32 BW,
int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w,
int32 stride_h, int32 stride_w,
int32 upsample_h, int32 upsample_w)";
    if(!is_a_deltas_cst)
      res += ", int32* delta";
    if(b_lut_ && !is_b_deltas_cst_)
      res += ", int32* b_delta";
    if(!is_mask_cst_)
      res += ", int32* masks";
    // Prologue: split the M index into (n, h, w) and the K index into
    // filter coordinates, then form the initial A pointers.
    res += R"(){
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 ldlut = )" + std::to_string(Fs_) + R"(;
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rab[TM] = rabh / CH;
int32 rah[TM] = rabh % CH;
raw = raw*stride_w - pad_w;
rah = rah*stride_h - pad_h;
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
rar = )" + flipr + R"( rar;
ras = )" + flips + R"( ras;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
    if(b_lut_){
      // B offsets come from B's own decomposition of rkb. Use rbs, not
      // A's ras, which is flipped on the non-transposed path.
      res += R"(
int32 rbcr[TK] = rkb / BW;
int32 rbs[TK] = rkb % BW;
int32 rbc[TK] = rbcr / BH;
int32 rbr[TK] = rbcr % BH;
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
int32 db[TK] = *pdb;)";
    }
    else{
      res += R"(
int32 rb1[TK] = rkb;)";
    }
    // B pointers, A delta/mask cursors, first tile loads, then the K loop.
    // Each iteration issues the loads for the next one right after dot().
    res += R"(
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
int32 d[TK] = *pd;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
)" + a_delta_mem + R"( int32* pincm[TM] = delta;
int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm;
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
fp32 b)" + BS + R"( = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C);
pa = pa + d[newaxis, :];
pb = pb + )" + inc_pb + R"(;
b = *pb;
pd = pd + incd;)";
    if(b_lut_){
      res += R"(
pdb = pdb + TK;
db = *pdb;)";
    }
    // Remaining cursor updates. `k > TK` disables the A prefetch on the
    // last iteration, since that tile is never consumed. The epilogue
    // computes the C bounds checks before the optional bias, so the bias
    // load can be predicated on rc1 < N.
    res += R"(
pincd = pincd + incd;
d = *pd;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
checka = checka && (k > TK);
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CH*CW);
int32 rcpq[TM] = rxc % (CH*CW);
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = rc1 < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
    if(bias_ && ty_==FPROP){
      // Guarded load: columns with rc1 >= N would read past the bias array.
      res += R"(
fp32* pbias[TN] = bias + rc1;
fp32 bias[TN] = checkc1 ? *pbias : 0;
C = C + bias[newaxis, :];)";
    }
    res += R"(
@checkc *pc = C;
})";
    return res;
  }

  // cpu check
  /// Host-side reference implementations (defined out of line).
  template<class IN_DTYPE, class OUT_DTYPE>
  void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
  template<class IN_DTYPE, class OUT_DTYPE>
  void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
  template<class IN_DTYPE, class OUT_DTYPE>
  void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);

private:
  // image size
  int32_t NB_;
  int32_t NC_;
  int32_t AD_;
  int32_t AH_;
  int32_t AW_;
  // filter size
  int32_t BD_;
  int32_t BH_;
  int32_t BW_;
  int32_t NF_;
  // activation size
  int32_t CD_;
  int32_t CH_;
  int32_t CW_;
  // striding
  int32_t stride_d_;
  int32_t stride_h_;
  int32_t stride_w_;
  // padding
  int32_t pad_d_;
  int32_t pad_h_;
  int32_t pad_w_;
  // upsampling
  int32_t upsample_d_;
  int32_t upsample_h_;
  int32_t upsample_w_;
  // equivalent matmul
  int32_t M_;
  int32_t N_;
  int32_t K_;
  // helpers
  // Fs_ is emitted as `ldlut`, the offset of the first LUT entry after the
  // increment section (pd = delta + ldlut, pm = masks + ldlut + ...).
  int32_t Fs_;
  int32_t TK_;
  int32_t Luts_;
  // memory strides for A
  std::vector<int32_t> shapes_a_;
  std::vector<int32_t> ld_a_;
  // memory strides for B
  std::vector<int32_t> shapes_b_;
  std::vector<int32_t> ld_b_;
  // memory stride for C
  std::vector<int32_t> shapes_c_;
  std::vector<int32_t> ld_c_;
  // constant memory
  // Host-side tables; their sizes dimension the `alloc_const` arrays.
  std::vector<int32_t> h_a_deltas_;
  std::vector<int32_t> h_b_deltas_;
  std::vector<int32_t> h_masks_;
  // Driver buffers — presumably device copies of the tables above.
  driver::buffer* d_a_deltas_;
  driver::buffer* d_b_deltas_;
  driver::buffer* d_masks_;
  // true: the table is declared `__constant__` inside the kernel;
  // false: it is passed as a trailing kernel argument.
  // NOTE(review): `is_a_deltas_cst` (and `c_pix_idx` below) lack the
  // trailing underscore; not renamed since out-of-line code may use them.
  bool is_a_deltas_cst;
  bool is_b_deltas_cst_;
  bool is_mask_cst_;
  // type
  type ty_;
  bool bias_;     // adds the bias epilogue (FPROP only)
  bool b_trans_;  // B tile is [TN,TK] and multiplied as trans(b)
  bool b_lut_;    // B advanced through the b_delta LUT, not a fixed stride
  // axis index (not referenced in this header)
  int32_t a_inner_idx_;
  int32_t a_outer_idx_;
  int32_t a_pix_idx_;
  int32_t b_inner_idx_;
  int32_t b_outer_idx_;
  int32_t b_pix_idx_;
  int32_t c_outer_0_idx_;
  int32_t c_outer_1_idx_;
  int32_t c_pix_idx;
};

}
}