[dnn/conv] added triton-c code for wgrad

This commit is contained in:
Philippe Tillet
2019-05-11 18:09:23 -04:00
parent fc4daf11dd
commit f6fe9492e4
3 changed files with 299 additions and 173 deletions

View File

@@ -10,13 +10,13 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context); triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::BPROP; triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
// initialization // initialization
int32_t B = 4, NF = 32; int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240; int32_t D = 1, H = 24, W = 240;
int32_t NC = 32, T = 1, R = 3, S = 3; int32_t NC = 32, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 1, pad_w = 1; int32_t pad_d = 0, pad_h = 1, pad_w = 1;
triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty); triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
// convolution configuration // convolution configuration
std::vector<float> hc(configuration.c_size()); std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size()); std::vector<float> rc(configuration.c_size());
@@ -40,8 +40,10 @@ int main() {
stream->synchronize(); stream->synchronize();
// look-up table // look-up table
std::vector<int> h_delta, h_masks; std::vector<int> h_delta, h_masks;
configuration.build_deltas(h_delta); if(ty != triton::dnn::conv::WGRAD){
configuration.build_masks(h_masks); configuration.build_deltas(h_delta);
configuration.build_masks(h_masks);
}
// benchmark a given convolution kernel // benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) { triton::jit::launch_information info) {
@@ -49,10 +51,12 @@ int main() {
unsigned TN = info.global_range_size[1]; unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads; unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration.get_grid(TM, TN); std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
triton::driver::buffer* delta = jit.get_buffer("delta"); if(ty != triton::dnn::conv::WGRAD){
triton::driver::buffer* masks = jit.get_buffer("masks"); triton::driver::buffer* delta = jit.get_buffer("delta");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
}
stream->synchronize(); stream->synchronize();
configuration.set_arg(kernel, da, db, dc); configuration.set_arg(kernel, da, db, dc);
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -69,11 +73,11 @@ int main() {
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc); stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data()); configuration.cpu_ref(rc.data(), ha.data(), hb.data());
// std::cout << c[0] << std::endl; for(size_t i = 0; i < hc.size(); i++){
for(size_t i = 0; i < hc.size(); i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
}
std::cout << "Pass!" << std::endl; std::cout << "Pass!" << std::endl;
} }

View File

@@ -1,5 +1,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <algorithm>
#include <numeric>
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
@@ -15,74 +17,91 @@ public:
}; };
conv(int B, int NC, int H, int W, int R, int S, int NF, conv(int B, int NC,
int upsample_h, int upsample_w, int D, int H, int W,
int pad_h, int pad_w, int T, int R, int S, int NF,
int upsample_d, int upsample_h, int upsample_w,
int pad_d, int pad_h, int pad_w,
type ty = FPROP) type ty = FPROP)
: B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF), : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w), upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
stride_d_(1), stride_h_(1), stride_w_(1), stride_d_(1), stride_h_(1), stride_w_(1),
pad_d_(0), pad_h_(pad_h), pad_w_(pad_w), pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
ty_(ty) ty_(ty)
{ {
RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_; CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_; CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_; CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
// memory strides for data // shapes
stride_a_w_ = 1; shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
stride_a_h_ = W_*stride_a_w_; shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
stride_a_d_ = H_*stride_a_h_; shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
stride_a_c_ = D_*stride_a_d_;
stride_a_n_ = NC_*stride_a_c_;
// memory stride for activations
stride_c_q_ = 1;
stride_c_p_ = RW_*stride_c_q_;
stride_c_m_ = RH_*stride_c_p_;
stride_c_k_ = RD_*stride_c_m_;
stride_c_n_ = NF_*stride_c_k_;
// swap a and c for bprop // swap a and c for bprop
if(ty_ == BPROP){ if(ty_ == BPROP){
std::swap(stride_a_n_, stride_c_n_); pad_d_ = (CD_ - AD_ + BD_ - 1) / 2;
std::swap(stride_a_c_, stride_c_k_); pad_h_ = (CH_ - AH_ + BH_ - 1) / 2;
std::swap(stride_a_h_, stride_c_p_); pad_w_ = (CW_ - AW_ + BW_ - 1) / 2;
std::swap(stride_a_w_, stride_c_q_); shapes_a_.swap(shapes_c_);
std::swap(D_, RD_);
std::swap(H_, RH_);
std::swap(W_, RW_);
std::swap(NF_, NC_);
pad_d_ = (RD_ - D_ + T_ - 1) / 2;
pad_h_ = (RH_ - H_ + R_ - 1) / 2;
pad_w_ = (RW_ - W_ + S_ - 1) / 2;
} }
// swap b and c for wgrad
if(ty_ == WGRAD){
shapes_b_.swap(shapes_c_);
}
// leading dimensions
auto set_ld = [](const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[4] = 1;
ld[3] = shapes[4]*ld[4];
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
};
set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// equivalent matmul // equivalent matmul
M_ = B_*RD_*RH_*RW_; if(ty_ == WGRAD){
N_ = NF_; M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3];
K_ = NC_*T_*R_*S_; N_ = shapes_c_[4];
K_ = shapes_b_[0]*shapes_b_[2]*shapes_b_[3]*shapes_b_[4];
}
else{
M_ = shapes_c_[0]*shapes_c_[2]*shapes_c_[3]*shapes_c_[4];
N_ = shapes_c_[1];
K_ = shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
}
// look-up table info // look-up table info
Fs_ = T_*R_*S_; Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
TK_ = 8; TK_ = 8;
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_; Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
} }
size_t a_size() { size_t a_size() {
return B_*NC_*D_*H_*W_; return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
1, std::multiplies<int>());
} }
size_t b_size() { size_t b_size() {
return NC_*NF_*T_*R_*S_; return std::accumulate(shapes_b_.begin(), shapes_b_.end(),
1, std::multiplies<int>());
} }
size_t c_size() { size_t c_size() {
return B_*NF_*RD_*RH_*RW_; return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
1, std::multiplies<int>());
} }
void build_deltas(std::vector<int>& deltas){ void build_deltas(std::vector<int>& deltas){
if(ty_ == WGRAD)
throw std::runtime_error("no look-up table necessary for wgrad");
deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
auto unpack = [&](int32_t trs){ auto unpack = [&](int32_t trs){
int32_t tr = trs / S_; int32_t tr = trs / BW_;
int32_t s = trs - tr*S_; int32_t s = trs - tr*BW_;
int32_t t = tr / R_; int32_t t = tr / BH_;
int32_t r = tr - t*R_; int32_t r = tr - t*BH_;
return std::make_tuple(t, r, s); return std::make_tuple(t, r, s);
}; };
for(size_t i = 0; i < Luts_; ++i) for(size_t i = 0; i < Luts_; ++i)
@@ -112,18 +131,20 @@ public:
int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_; int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_; int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
// delta pointers // delta pointers
deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_; deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
} }
} }
} }
void build_masks(std::vector<int>& masks){ void build_masks(std::vector<int>& masks){
if(ty_ == WGRAD)
throw std::runtime_error("no look-up table necessary for wgrad");
masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_); masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
auto unpack = [&](int32_t trs){ auto unpack = [&](int32_t trs){
int32_t tr = trs / S_; int32_t tr = trs / BW_;
int32_t s = trs - tr*S_; int32_t s = trs - tr*BW_;
int32_t t = tr / R_; int32_t t = tr / BH_;
int32_t r = tr - t*R_; int32_t r = tr - t*BH_;
return std::make_tuple(t, r, s); return std::make_tuple(t, r, s);
}; };
size_t Ms0 = Luts_; size_t Ms0 = Luts_;
@@ -139,9 +160,9 @@ public:
int32_t mask = 0x0; int32_t mask = 0x0;
for(size_t j = 0; j < TK_; ++j){ for(size_t j = 0; j < TK_; ++j){
std::tie(t, r, s) = unpack((i + j) % Fs_); std::tie(t, r, s) = unpack((i + j) % Fs_);
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_); bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (BD_ + pad_d_);
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_); bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (BH_ + pad_h_);
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_); bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (BW_ + pad_w_);
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j; mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
} }
masks_ptr[i] = mask; masks_ptr[i] = mask;
@@ -168,46 +189,40 @@ public:
kernel->setArg(3, M_); kernel->setArg(3, M_);
kernel->setArg(4, N_); kernel->setArg(4, N_);
kernel->setArg(5, K_); kernel->setArg(5, K_);
kernel->setArg(6, B_); kernel->setArg(6, AH_);
kernel->setArg(7, H_); kernel->setArg(7, AW_);
kernel->setArg(8, W_); kernel->setArg(8, BH_);
kernel->setArg(9, NF_); kernel->setArg(9, BW_);
kernel->setArg(10, RH_); kernel->setArg(10, CH_);
kernel->setArg(11, RW_); kernel->setArg(11, CW_);
kernel->setArg(12, NC_); kernel->setArg(12, ld_a_[0]);
kernel->setArg(13, R_); kernel->setArg(13, ld_a_[1]);
kernel->setArg(14, S_); kernel->setArg(14, ld_a_[2]);
kernel->setArg(15, stride_a_n_); kernel->setArg(15, ld_a_[3]);
kernel->setArg(16, stride_a_c_); kernel->setArg(16, ld_a_[4]);
kernel->setArg(17, stride_a_h_); kernel->setArg(17, ld_b_[0]);
kernel->setArg(18, stride_a_w_); kernel->setArg(18, ld_b_[1]);
kernel->setArg(19, stride_c_n_); kernel->setArg(19, ld_b_[2]);
kernel->setArg(20, stride_c_k_); kernel->setArg(20, ld_b_[3]);
kernel->setArg(21, stride_c_p_); kernel->setArg(21, ld_b_[4]);
kernel->setArg(22, stride_c_q_); kernel->setArg(22, ld_c_[0]);
kernel->setArg(23, pad_h_); kernel->setArg(23, ld_c_[1]);
kernel->setArg(24, pad_w_); kernel->setArg(24, ld_c_[2]);
kernel->setArg(25, ld_c_[3]);
kernel->setArg(26, ld_c_[4]);
kernel->setArg(27, pad_h_);
kernel->setArg(28, pad_w_);
} }
std::vector<unsigned> default_params() { std::vector<unsigned> default_params() {
// if(ty_ == FPROP) if(ty_ == FPROP || ty_ == BPROP)
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
// else else
// return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2}; return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8};
} }
std::string src() { std::string xprop() {
std::string bs0 = "TN", bs1 = "TK";
std::string ldb0 = "*NF", ldb1 = "";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string b = "b";
if(ty_ == BPROP){
std::swap(bs0, bs1);
std::swap(ldb0, ldb1);
std::swap(bcb0, bcb1);
b = "trans(b)";
}
std::string res = std::string res =
R"( R"(
const tunable int32 TM = {16, 32, 64}; const tunable int32 TM = {16, 32, 64};
@@ -221,36 +236,37 @@ public:
read_only restrict fp32 *b, read_only restrict fp32 *b,
fp32 *c, fp32 *c,
int32 M, int32 N, int32 K, int32 M, int32 N, int32 K,
int32 B, int32 H, int32 W, int32 AH, int32 AW,
int32 NF, int32 RH, int32 RW, int32 BH, int32 BW,
int32 NC, int32 R, int32 S, int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w, int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q, int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w){ int32 pad_h, int32 pad_w){
int32 rxa[TM] = get_global_range[TM](0); int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1); int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK; int32 rka[TK] = 0 ... TK;
int32 rb1[TK] = 0 ... TK; int32 rb1[TK] = 0 ... TK;
fp32 C[TM, TN] = 0; fp32 C[TM, TN] = 0;
int32 rabh[TM] = rxa / RW; int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % RW - pad_w; int32 raw[TM] = rxa % CW - pad_w;
int32 rab[TM] = rabh / RH; int32 rab[TM] = rabh / CH;
int32 rah[TM] = rabh % RH - pad_h; int32 rah[TM] = rabh % CH - pad_h;
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int32 racr[TK] = rka / S; int32 racr[TK] = rka / BW;
int32 ras[TK] = rka % S; int32 ras[TK] = rka % BW;
int32 rac[TK] = racr / R; int32 rac[TK] = racr / BH;
int32 rar[TK] = racr % R; int32 rar[TK] = racr % BH;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis]; fp32* pb[TN, TK] = b + rb1[newaxis, :]*ldb_s + rb0[:, newaxis];
__constant__ int32* pincd[TK] = delta + rka; __constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + R*S + rka; __constant__ int32* pd[TK] = delta + BH*BW + rka;
int32 d[TK] = *pd; int32 d[TK] = *pd;
int32 incd[TK] = *pincd; int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + R - H, 0); int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + S - W, 0); int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
__constant__ int32* pm[TM] = masks + R*S + maskw*R*S + maskh*R*S*(2*pad_w + 1); __constant__ int32* pm[TM] = masks + BH*BW + maskw*BH*BW + maskh*BH*BW*(2*pad_w + 1);
__constant__ int32* pincm[TM] = delta; __constant__ int32* pincm[TM] = delta;
int32 incm[TM] = *pincm; int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm; int32 checka0[TM] = *pm;
@@ -260,7 +276,7 @@ public:
fp32 b[TN, TK] = *pb; fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){ for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C); C = dot(a, trans(b), C);
pb = pb + TK*NF; pb = pb + TK*ldb_s;
pa = pa + d[newaxis, :]; pa = pa + d[newaxis, :];
b = *pb; b = *pb;
pd = pd + incd; pd = pd + incd;
@@ -276,8 +292,8 @@ public:
} }
int32 rxc[TM] = get_global_range[TM](0); int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1); int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (RH*RW); int32 rcn[TM] = rxc / (CH*CW);
int32 rcpq[TM] = rxc % (RH*RW); int32 rcpq[TM] = rxc % (CH*CW);
int32 rc0[TM] = rcn * ldc_n + rcpq; int32 rc0[TM] = rcn * ldc_n + rcpq;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M; int1 checkc0[TM] = rxc < M;
@@ -288,62 +304,169 @@ public:
return res; return res;
} }
// C = A * B
// where A is N,C,AH,AW
//       B is N,K,BH,BW
//       C is C,CH,CW,K
// Generates the triton-C source for the weight-gradient (WGRAD) kernel.
// Naming note: relative to xprop, "B" here is the output-gradient tensor and
// the kernel output "C" is the filter gradient, so CH/CW denote the filter's
// spatial extent while BH/BW denote the output-gradient's (argument order in
// set_arg matches this because shapes_b_/shapes_c_ were swapped for WGRAD).
std::string wgrad() {
std::string res =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};
void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
int32 M, int32 N, int32 K,
int32 AH, int32 AW,
int32 CH, int32 CW,
int32 BH, int32 BW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
int32 pad_h, int32 pad_w){
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rk[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 racr[TM] = rxa / CW;
int32 raw_base[TM] = rxa % CW - pad_w;
int32 rac[TM] = racr / CH;
int32 rah_base[TM] = racr % CH - pad_h;
fp32* pa_base[TM, TK] = a + rac[:, newaxis]*lda_c;
fp32* pb_base[TN, TK] = b + ryb[:, newaxis]*ldb_k;
for(int32 k = K; k > 0; k = k - TK){
int32 rknp[TK] = rk / BW;
int32 rkq[TK] = rk % BW;
int32 rkn[TK] = rknp / BH;
int32 rkp[TK] = rknp % BH;
int32 rah[TM, TK] = rah_base[:, newaxis] + rkp[newaxis, :];
int32 raw[TM, TK] = raw_base[:, newaxis] + rkq[newaxis, :];
int1 checka[TM, TK] = (rah >= 0) && (rah < AH) && (raw >= 0) && (raw < AW);
fp32* pa[TM, TK] = pa_base + rah*lda_h + raw*lda_w + rkn*lda_n;
fp32* pb[TN, TK] = pb_base + rkp*ldb_p + rkq*ldb_q + rkn*ldb_n;
fp32 A[TM, TK] = checka ? *pa : 0;
fp32 B[TN, TK] = *pb;
C = dot(A, trans(B), C);
rk = rk + TK;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
int32 rccr[TM] = rxc / CW;
int32 rcs[TM] = rxc % CW;
int32 rcc[TM] = rccr / CH;
int32 rcr[TM] = rccr % CH;
int32 rc0[TM] = rcc*ldc_c + rcr*ldc_r + rcs*ldc_s;
fp32* pc[TM, TN] = c + rc0[:, newaxis] + ryc[newaxis, :]*ldc_k;
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
})";
// BUGFIX: the epilogue previously decomposed the output index with the stale
// prologue registers (rcs from rxa, rcc/rcr from racr) while the freshly
// computed rccr went unused, so the filter-gradient store addresses were
// wrong; it now uses rxc/rccr consistently.
return res;
}
std::string src() {
if(ty_ == FPROP || ty_ == BPROP)
return xprop();
else
return wgrad();
}
// CPU reference implementation for FPROP/BPROP: a direct (naive) convolution
// over the shapes_a_/shapes_b_/shapes_c_ layouts set up in the constructor.
// For BPROP, a and c were already swapped there, so this same loop nest also
// produces the data gradient.
// C: output buffer, laid out per ld_c_; A: input image, per ld_a_;
// B: filter, per ld_b_ (shapes_b_ = {C, T, R, S, K} — see constructor).
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{
IN_DTYPE acc;
// One output element per (image n, filter k, spatial position cd/ch/cw).
for(int32_t n = 0; n < shapes_c_[0]; ++n)
for(int32_t k = 0; k < shapes_c_[1] ; ++k)
for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd)
for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch)
for(int32_t cw = 0; cw < shapes_c_[4]; ++cw)
{
acc = 0;
// Top-left-front corner of the receptive field in A; may be negative
// because of padding (out-of-bounds reads are zero-filled below).
int32_t d = cd*stride_d_ - pad_d_;
int32_t h = ch*stride_h_ - pad_h_;
int32_t w = cw*stride_w_ - pad_w_;
// Reduce over input channels and the filter's spatial extent.
for(int32_t c = 0; c < shapes_b_[0]; ++c)
for(int32_t bd = 0; bd < shapes_b_[1]; ++bd)
for(int32_t bh = 0; bh < shapes_b_[2]; ++bh)
for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){
int32_t ad = d + bd;
int32_t ah = h + bh;
int32_t aw = w + bw;
// Zero-padding: treat reads outside A's spatial bounds as 0.
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
ah >= 0 && ah < shapes_a_[3] &&
aw >= 0 && aw < shapes_a_[4]);
IN_DTYPE a = 0;
if(in_bounds)
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
IN_DTYPE b = B[c*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + k*ld_b_[4]];
// Fused multiply-add keeps the accumulation in IN_DTYPE precision.
acc = std::fma(a, b, acc);
}
C[n*ld_c_[0] + k*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
}
}
// CPU reference implementation for WGRAD: accumulates the filter gradient.
// NOTE(review): for WGRAD the constructor swapped shapes_b_ and shapes_c_,
// so here shapes_c_ holds the filter shape {C, T, R, S, K} and shapes_b_
// holds the output-gradient shape {N, K, CD, CH, CW} — confirm against the
// constructor if those shapes change.
// C: filter-gradient output, per ld_c_; A: input image, per ld_a_;
// B: output gradient, per ld_b_.
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{
IN_DTYPE acc;
// One output element per filter weight (channel c, tap cd/ch/cw, filter k).
for(int32_t c = 0 ; c < shapes_c_[0]; ++c)
for(int32_t cd = 0; cd < shapes_c_[1]; ++cd)
for(int32_t ch = 0; ch < shapes_c_[2]; ++ch)
for(int32_t cw = 0; cw < shapes_c_[3]; ++cw)
for(int32_t k = 0 ; k < shapes_c_[4]; ++k)
{
acc = 0;
// Base offset of this filter tap into A, shifted by padding; with
// stride 1 this is equivalent to striding the output index instead.
int32_t d = cd*stride_d_ - pad_d_;
int32_t h = ch*stride_h_ - pad_h_;
int32_t w = cw*stride_w_ - pad_w_;
// Reduce over all images and all output-gradient spatial positions.
for(int32_t n = 0; n < shapes_b_[0]; ++n)
for(int32_t bd = 0; bd < shapes_b_[2]; ++bd)
for(int32_t bh = 0; bh < shapes_b_[3]; ++bh)
for(int32_t bw = 0; bw < shapes_b_[4]; ++bw){
int32_t ad = d + bd;
int32_t ah = h + bh;
int32_t aw = w + bw;
// Zero-padding: input reads outside A's spatial bounds contribute 0.
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
ah >= 0 && ah < shapes_a_[3] &&
aw >= 0 && aw < shapes_a_[4]);
IN_DTYPE a = 0;
if(in_bounds)
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
IN_DTYPE b = B[n*ld_b_[0] + k*ld_b_[1] + bd*ld_b_[2] + bh*ld_b_[3] + bw*ld_b_[4]];
// Fused multiply-add keeps the accumulation in IN_DTYPE precision.
acc = std::fma(a, b, acc);
}
C[c*ld_c_[0] + cd*ld_c_[1] + ch*ld_c_[2] + cw*ld_c_[3] + k*ld_c_[4]] = acc;
}
}
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{ {
auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u, if(ty_ == FPROP || ty_ == BPROP)
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) cpu_xprop(C, A, B);
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }; else
cpu_wgrad(C, A, B);
IN_DTYPE accs[1];
float tmp[1];
for(int32_t m = 0 ; m < RD_; ++m)
for(int32_t p = 0 ; p < RH_; ++p)
for(int32_t q = 0; q < RW_; ++q)
for(int32_t n = 0; n < B_; ++n)
for(int32_t k = 0; k < NF_ ; ++k)
{
for(int32_t i = 0; i < 1; ++i)
accs[i] = 0;
int32_t mm = m*stride_d_ - pad_d_;
int32_t pp = p*stride_h_ - pad_h_;
int32_t qq = q*stride_w_ - pad_w_;
for(int32_t kk = 0; kk < 1; ++kk)
for(int32_t c = 0; c < NC_; ++c)
for(int32_t t = 0; t < T_; ++t)
for(int32_t r = 0; r < R_; ++r)
for(int32_t s = 0; s < S_; ++s){
int32_t d = mm + t;
int32_t h = pp + r;
int32_t w = qq + s;
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_);
IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0;
IN_DTYPE b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
accs[kk] = std::fma(a, b, accs[kk]);
}
for(int32_t kk = 0; kk < 1; ++kk){
tmp[kk] = accs[kk];
}
C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0];
}
} }
private: private:
// image size // image size
int32_t B_; int32_t NB_;
int32_t NC_; int32_t NC_;
int32_t D_; int32_t AD_;
int32_t H_; int32_t AH_;
int32_t W_; int32_t AW_;
// filter size // filter size
int32_t T_; int32_t BD_;
int32_t R_; int32_t BH_;
int32_t S_; int32_t BW_;
int32_t NF_; int32_t NF_;
// activation size // activation size
int32_t RD_; int32_t CD_;
int32_t RH_; int32_t CH_;
int32_t RW_; int32_t CW_;
// upsampling // upsampling
int32_t upsample_d_; int32_t upsample_d_;
int32_t upsample_h_; int32_t upsample_h_;
@@ -364,18 +487,15 @@ private:
int32_t Fs_; int32_t Fs_;
int32_t TK_; int32_t TK_;
int32_t Luts_; int32_t Luts_;
// memory strides for data // memory strides for A
int32_t stride_a_w_; std::vector<int32_t> shapes_a_;
int32_t stride_a_h_; std::vector<int32_t> ld_a_;
int32_t stride_a_d_; // memory strides for B
int32_t stride_a_c_; std::vector<int32_t> shapes_b_;
int32_t stride_a_n_; std::vector<int32_t> ld_b_;
// memory stride for activations // memory stride for C
int32_t stride_c_q_; std::vector<int32_t> shapes_c_;
int32_t stride_c_p_; std::vector<int32_t> ld_c_;
int32_t stride_c_m_;
int32_t stride_c_k_;
int32_t stride_c_n_;
// type // type
type ty_; type ty_;
bool is_bprop_; bool is_bprop_;

View File

@@ -812,7 +812,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
std::swap(b_idx[0], b_idx[1]); std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx); Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx); Value *b = TB->get_value(b_idx);
// res = builder.CreateCall(f_mul_add, {ConstantFP::get(a->getType(), 1), ConstantFP::get(b->getType(), 1), res});
res = builder.CreateCall(f_mul_add, {a, b, res}); res = builder.CreateCall(f_mul_add, {a, b, res});
} }
result->set_value(idx, res); result->set_value(idx, res);
}); });