[dnn/conv] some minor fixes
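In short: triton::dnn::conv gains a FPROP/BPROP type parameter, build_lut is split into build_deltas and build_masks, the launch-grid, FLOP-count and kernel-argument logic of the example moves into the class as get_grid, get_nflops and set_arg, src() and default_params() become instance methods so they can specialize on the convolution type, and a templated cpu_ref replaces the external cpp_conv_nchw reference. A condensed sketch of the resulting calling sequence (illustrative, not a verbatim excerpt of the updated example below):

    triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
    triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty);
    std::vector<int> h_delta, h_masks;
    configuration.build_deltas(h_delta);                 // pointer-increment LUT
    configuration.build_masks(h_masks);                  // boundary-mask LUT
    jit.add_module("conv", configuration.src().c_str(), configuration.default_params());
    configuration.set_arg(kernel, da, db, dc);           // swaps a and c internally for BPROP
    stream->enqueue(kernel, configuration.get_grid(TM, TN), {nthreads, 1, 1});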
@@ -9,22 +9,19 @@
 int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   // initialize just-in-time compiler
   triton::jit jit(context);
+  triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
   // initialization
   int32_t B = 4, NF = 32;
   int32_t D = 1, H = 24, W = 240;
-  int32_t NC = 64, T = 1, R = 3, S = 3;
+  int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 1, pad_w = 1;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  int32_t RD = (D*upsample_d - T + 1 + 2*pad_d + stride_d - 1)/stride_d;
   int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
   int32_t RW = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
-  // equivalent matmul dimensions
-  int32_t M = B*RD*RH*RW;
-  int32_t N = NF;
-  int32_t K = NC*T*R*S;
+  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty);
   // convolution configuration
   std::vector<float> hc(B*RH*RW*NF);
   std::vector<float> rc(B*RH*RW*NF);
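As a quick sanity check on the output-shape formula above (mine, not part of the commit), with the example's numbers:

    int32_t H = 24, R = 3, pad_h = 1, stride_h = 1, upsample_h = 1;
    int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
    // RH = (24 - 3 + 1 + 2 + 0)/1 = 24: with pad = (R-1)/2 the height is preserved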
@@ -36,7 +33,8 @@ int main() {
   for(size_t i = 0; i < hb.size(); i++)
     hb[i] = (float)rand()/RAND_MAX;
   for(size_t i = 0; i < hc.size(); i++)
-    hc[i] = 0;
+    hc[i] = (float)rand()/RAND_MAX;
+  rc = hc;
   triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
   triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
   triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
@@ -45,80 +43,38 @@ int main() {
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  // memory strides for data
-  int32_t stride_i_w = 1;
-  int32_t stride_i_h = W*stride_i_w;
-  int32_t stride_i_d = H*stride_i_h;
-  int32_t stride_i_c = D*stride_i_d;
-  int32_t stride_i_n = NC*stride_i_c;
-  // memory stride for activations
-  int32_t stride_o_q = 1;
-  int32_t stride_o_p = RW*stride_o_q;
-  int32_t stride_o_m = RH*stride_o_p;
-  int32_t stride_o_k = RD*stride_o_m;
-  int32_t stride_o_n = NF*stride_o_k;
-  // look-up table
-  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
   std::vector<int> h_delta, h_masks;
-  configuration.build_lut(h_delta, h_masks);
+  configuration.build_deltas(h_delta);
+  configuration.build_masks(h_masks);
   // benchmark a given convolution kernel
   auto benchmark = [&](triton::driver::kernel* kernel,
                        triton::jit::launch_information info) {
-    // launch info
     unsigned TM = info.global_range_size[0];
     unsigned TN = info.global_range_size[1];
-    // initialize constant memory
+    unsigned nthreads = info.num_threads;
+    std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
     triton::driver::buffer* delta = jit.get_buffer("delta");
     triton::driver::buffer* masks = jit.get_buffer("masks");
     stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
     stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
     stream->synchronize();
-    // launch info
-    unsigned nthreads = info.num_threads;
-    std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
-    // set arguments
-    kernel->setArg(0, da);
-    kernel->setArg(1, db);
-    kernel->setArg(2, dc);
-    kernel->setArg(3, M);
-    kernel->setArg(4, N);
-    kernel->setArg(5, K);
-    kernel->setArg(6, B);
-    kernel->setArg(7, H);
-    kernel->setArg(8, W);
-    kernel->setArg(9, NF);
-    kernel->setArg(10, RH);
-    kernel->setArg(11, RW);
-    kernel->setArg(12, NC);
-    kernel->setArg(13, R);
-    kernel->setArg(14, S);
-    kernel->setArg(15, stride_i_n);
-    kernel->setArg(16, stride_i_c);
-    kernel->setArg(17, stride_i_h);
-    kernel->setArg(18, stride_i_w);
-    kernel->setArg(19, stride_o_n);
-    kernel->setArg(20, stride_o_k);
-    kernel->setArg(21, stride_o_p);
-    kernel->setArg(22, stride_o_q);
-    kernel->setArg(23, pad_h);
-    kernel->setArg(24, pad_w);
-    // dry run
+    configuration.set_arg(kernel, da, db, dc);
     stream->enqueue(kernel, grid, {nthreads, 1, 1});
     stream->synchronize();
     // benchmark
     double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
                       [&](){ stream->synchronize(); }, *context->device());
-    return 2.*M*N*K / ts * 1e-3;
+    return configuration.get_nflops() / ts * 1e-3;
   };
-  std::string src = triton::dnn::conv::src();
+  std::string src = configuration.src();
   // jit.autotune("conv", src.c_str(), benchmark);
-  jit.add_module("conv", src.c_str(), triton::dnn::conv::default_params());
+  jit.add_module("conv", src.c_str(), configuration.default_params());
   triton::driver::kernel* kernel = jit.get_function("conv");
   triton::jit::launch_information info = jit.get_launch_info("conv");
   std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
   stream->read(dc, true, 0, hc);
-  cpp_conv_nchw(NC, B, NF, D, H, W, T, R, S, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, RD, RH, RW, rc, ha, hb);
-  for(size_t i = 0; i < M*N; i++)
+  configuration.cpu_ref(rc.data(), ha.data(), hb.data());
+  for(size_t i = 0; i < hc.size(); i++)
   if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
     exit(EXIT_FAILURE);

@@ -1,5 +1,7 @@
 #include <string>
 #include <vector>
+#include "triton/driver/stream.h"
+#include "triton/driver/kernel.h"
 
 namespace triton{
 namespace dnn{
@@ -15,10 +17,13 @@ public:
 
   conv(int B, int NC, int H, int W, int R, int S, int NF,
        int upsample_h, int upsample_w,
-       int pad_h, int pad_w)
+       int pad_h, int pad_w,
+       type ty = FPROP)
     : B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
       upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
-      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w)
+      stride_d_(1), stride_h_(1), stride_w_(1),
+      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w),
+      ty_(ty)
   {
     RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
     RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
@@ -26,9 +31,6 @@ public:
     M_ = B*RD_*RH_*RW_;
     N_ = NF;
     K_ = NC*T_*R_*S_;
-    Fs_ = T_*R_*S_;
-    TK_ = 8;
-    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
     // memory strides for data
     stride_a_w_ = 1;
     stride_a_h_ = W_*stride_a_w_;
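The Fs_/TK_/Luts_ lines removed here reappear at the end of the constructor, after the new BPROP branch (next hunk). For reference, Luts_ rounds the tile depth TK_ up to a whole number of filter positions, e.g. (arithmetic only, not in the commit):

    // TK_ = 8, Fs_ = T_*R_*S_ = 1*3*3 = 9  =>  Luts_ = (8 + 9 - 1)/9 * 9 = 9
    // so one LUT period always ends on a filter-position boundary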
@@ -41,88 +43,160 @@ public:
     stride_c_m_ = RH_*stride_c_p_;
     stride_c_k_ = RD_*stride_c_m_;
     stride_c_n_ = NF_*stride_c_k_;
+    // swap a and c for bprop
+    if(ty_ == BPROP){
+      std::swap(stride_a_n_, stride_c_n_);
+      std::swap(stride_a_c_, stride_c_k_);
+      std::swap(stride_a_h_, stride_c_p_);
+      std::swap(stride_a_w_, stride_c_q_);
+      std::swap(D_, RD_);
+      std::swap(H_, RH_);
+      std::swap(W_, RW_);
+      pad_d_ = (RD_ - D_ + T_ - 1) / 2;
+      pad_h_ = (RH_ - H_ + R_ - 1) / 2;
+      pad_w_ = (RW_ - W_ + S_ - 1) / 2;
+    }
+    // look-up table info
+    Fs_ = T_*R_*S_;
+    TK_ = 8;
+    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
   }
 
 
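A reading note on the new BPROP branch above (mine, not the author's): the backward data pass is expressed as a convolution over the output gradient, so image/activation shapes and their strides are swapped and the padding is recomputed as the transposed-convolution ("full") padding. With the example's stride-1, R = 3, pad_h = 1 setting, the numbers are self-consistent:

    // forward:  RH = H - R + 1 + 2*pad_h = H          (pad_h = 1, R = 3)
    // backward: pad_h_ = (RH_ - H_ + R_ - 1)/2 = (0 + 2)/2 = 1 = (R - 1) - pad_h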
-  void build_lut(std::vector<int>& delta, std::vector<int>& masks) {
-    delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
-    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
-
-    /* unpack index wrt filters */
-    auto unpack = [&](int32_t trs){
-      int32_t tr = trs / S_;
-      int32_t s = trs - tr*S_;
-      int32_t t = tr / R_;
-      int32_t r = tr - t*R_;
-      return std::make_tuple(t, r, s);
-    };
-    /* increments */
-    for(size_t i = 0; i < Luts_; ++i)
-      delta[i] = (((i + TK_) % Luts_) - i);
-    /* deltas */
-    size_t Ds0 = Luts_;
-    size_t Ds1 = upsample_w_;
-    size_t Ds2 = upsample_h_;
-    size_t Ds3 = upsample_d_;
-    for(size_t pd = 0; pd < Ds3; ++pd)
-    for(size_t ph = 0; ph < Ds2; ++ph)
-    for(size_t pw = 0; pw < Ds1; ++pw){
-      int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
-      // cumulative increments
-      for(size_t i = 0; i < Ds0; ++i){
-        int32_t ctrs = i;
-        int32_t c = ctrs / Fs_;
-        int32_t t, r, s;
-        std::tie(t, r, s) = unpack(ctrs % Fs_);
-        // next indices
-        int32_t nextctrs = ctrs + TK_;
-        int32_t nextc = nextctrs / Fs_;
-        int32_t nextt, nextr, nexts;
-        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
-        // diffs
-        int32_t cdiff = nextc - c;
-        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
-        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
-        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
-        // delta pointers
-        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
-      }
+  void build_deltas(std::vector<int>& deltas){
+    deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
+    auto unpack = [&](int32_t trs){
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
+      return std::make_tuple(t, r, s);
+    };
+    for(size_t i = 0; i < Luts_; ++i)
+      deltas[i] = (((i + TK_) % Luts_) - i);
+    size_t Ds0 = Luts_;
+    size_t Ds1 = upsample_w_;
+    size_t Ds2 = upsample_h_;
+    size_t Ds3 = upsample_d_;
+    for(size_t pd = 0; pd < Ds3; ++pd)
+    for(size_t ph = 0; ph < Ds2; ++ph)
+    for(size_t pw = 0; pw < Ds1; ++pw){
+      int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
+      // cumulative increments
+      for(size_t i = 0; i < Ds0; ++i){
+        int32_t ctrs = i;
+        int32_t c = ctrs / Fs_;
+        int32_t t, r, s;
+        std::tie(t, r, s) = unpack(ctrs % Fs_);
+        // next indices
+        int32_t nextctrs = ctrs + TK_;
+        int32_t nextc = nextctrs / Fs_;
+        int32_t nextt, nextr, nexts;
+        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
+        // diffs
+        int32_t cdiff = nextc - c;
+        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
+        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
+        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
+        // delta pointers
+        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
+      }
+    }
+  }
+
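How to read build_deltas (an illustration, not part of the commit): the first Luts_ entries hold the circular step from LUT slot i to slot i + TK_, and each (pd, ph, pw) upsampling phase then gets a table of pointer increments in elements of A. The flat filter index unpacks with s fastest:

    // trs over a T x R x S filter, e.g. R = S = 3, trs = 5:
    int32_t tr = trs / S;     // 1
    int32_t s  = trs - tr*S;  // 2
    int32_t t  = tr / R;      // 0
    int32_t r  = tr - t*R;    // 1   =>  (t, r, s) = (0, 1, 2)

The stored value is the difference between the current and the next unpacked position, scaled by the A strides, which is what lets the kernel advance pa by the preloaded d = *pd in the source-template hunks further down instead of re-deriving indices every iteration.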
-    /* Masks */
-    size_t Ms0 = Luts_;
-    size_t Ms1 = 2*pad_w_ + 1;
-    size_t Ms2 = 2*pad_h_ + 1;
-    size_t Ms3 = 2*pad_d_ + 1;
-    for(size_t pd = 0; pd < Ms3; ++pd)
-    for(size_t ph = 0; ph < Ms2; ++ph)
-    for(size_t pw = 0; pw < Ms1; ++pw){
-      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
-      for(size_t i = 0; i < Ms0; ++i){
-        int32_t t, r, s;
-        int32_t mask = 0x0;
-        for(size_t j = 0; j < TK_; ++j){
-          std::tie(t, r, s) = unpack((i + j) % Fs_);
-          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
-          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
-          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
-          mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
-        }
-        masks_ptr[i] = mask;
-      }
+  void build_masks(std::vector<int>& masks){
+    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
+    auto unpack = [&](int32_t trs){
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
+      return std::make_tuple(t, r, s);
+    };
+    size_t Ms0 = Luts_;
+    size_t Ms1 = 2*pad_w_ + 1;
+    size_t Ms2 = 2*pad_h_ + 1;
+    size_t Ms3 = 2*pad_d_ + 1;
+    for(size_t pd = 0; pd < Ms3; ++pd)
+    for(size_t ph = 0; ph < Ms2; ++ph)
+    for(size_t pw = 0; pw < Ms1; ++pw){
+      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
+      for(size_t i = 0; i < Ms0; ++i){
+        int32_t t, r, s;
+        int32_t mask = 0x0;
+        for(size_t j = 0; j < TK_; ++j){
+          std::tie(t, r, s) = unpack((i + j) % Fs_);
+          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
+          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
+          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
+          mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
+        }
+        masks_ptr[i] = mask;
+      }
-    for(size_t i = 0; i < Luts_; ++i)
-      masks[i] = 0x0;
-
-  }
+    for(size_t i = 0; i < Luts_; ++i)
+      masks[i] = 0x0;
+  }
 
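build_masks packs, for each padding phase (pd, ph, pw), one word per LUT slot whose bit j says whether filter step j of the current TK_-step window stays in bounds. The kernel unpacks it with the shift-and-test visible in the source template further down (checka1 = 1 << rka, then (checka0 & checka1) > 0); an equivalent scalar decoding (illustrative only):

    int32_t mask = masks_ptr[i];
    bool in_bounds_j = (mask >> j) & 1;  // step j of the current TK_-step window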
-  static std::vector<unsigned> default_params() {
-    return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
+  std::array<size_t, 3> get_grid(size_t TM, size_t TN){
+    return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
   }
 
+  size_t get_nflops(){
+    return 2.*M_*N_*K_;
+  }
 
+  void set_arg(driver::kernel *kernel,
+               driver::buffer *a, driver::buffer *b, driver::buffer *c)
+  {
+
+    if(ty_ == BPROP)
+      std::swap(a, c);
+    kernel->setArg(0, a);
+    kernel->setArg(1, b);
+    kernel->setArg(2, c);
+    kernel->setArg(3, M_);
+    kernel->setArg(4, N_);
+    kernel->setArg(5, K_);
+    kernel->setArg(6, B_);
+    kernel->setArg(7, H_);
+    kernel->setArg(8, W_);
+    kernel->setArg(9, NF_);
+    kernel->setArg(10, RH_);
+    kernel->setArg(11, RW_);
+    kernel->setArg(12, NC_);
+    kernel->setArg(13, R_);
+    kernel->setArg(14, S_);
+    kernel->setArg(15, stride_a_n_);
+    kernel->setArg(16, stride_a_c_);
+    kernel->setArg(17, stride_a_h_);
+    kernel->setArg(18, stride_a_w_);
+    kernel->setArg(19, stride_c_n_);
+    kernel->setArg(20, stride_c_k_);
+    kernel->setArg(21, stride_c_p_);
+    kernel->setArg(22, stride_c_q_);
+    kernel->setArg(23, pad_h_);
+    kernel->setArg(24, pad_w_);
+  }
 
+  std::vector<unsigned> default_params() {
+    if(ty_ == FPROP)
+      return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
+    else
+      return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2};
+  }
 
 
-  static std::string src(type ty = FPROP) {
 
+  std::string src() {
+    std::string bs0 = "TN", bs1 = "TK";
+    std::string ldb0 = "*NF", ldb1 = "";
+    std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
+    std::string b = "b";
+    if(ty_ == BPROP){
+      std::swap(bs0, bs1);
+      std::swap(ldb0, ldb1);
+      std::swap(bcb0, bcb1);
+      b = "trans(b)";
+    }
     std::string res =
     R"(
     const tunable int32 TM = {16, 32, 64};
@@ -158,7 +232,7 @@ public:
     int32 rar[TK] = racr % R;
     int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
     fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
-    fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];
+    fp32* pb[)" + bs0 + ", " + bs1 + R"(] = b + rb1)" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
     __constant__ int32* pincd[TK] = delta + rka;
     __constant__ int32* pd[TK] = delta + R*S + rka;
     int32 d[TK] = *pd;
@@ -172,10 +246,10 @@ public:
     int32 checka1[TK] = 1 << rka;
     int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
     fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b[TN, TK] = *pb;
+    fp32 b[)" + bs0 + ", " + bs1 + R"(] = *pb;
     for(int32 k = K; k > 0; k = k - TK){
       C = dot(a, trans(b), C);
-      pb = pb + TK*NF;
+      pb = pb + TK)" + ldb0 + R"(;
       pa = pa + d[newaxis, :];
       b = *pb;
       pd = pd + incd;
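Expanding the string substitutions by hand (my derivation; worth checking against the generated source), the templated pb lines above produce:

    fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];  // FPROP: identical to the removed hard-coded line
    fp32* pb[TK, TN] = b + rb1[:, newaxis] + rb0[newaxis, :]*NF;  // BPROP: tile shape and leading dimension swapped

The FPROP expansion reproducing the old hard-coded line exactly is a quick check that the substitution variables are wired up correctly.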
@@ -203,42 +277,90 @@ public:
     return res;
   }
 
+  template<class IN_DTYPE, class OUT_DTYPE>
+  void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
+  {
+    auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
+                   int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
+    { return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; };
+
+    if(ty_==BPROP){
+      std::swap(A, C);
+    }
+    std::cout << A[0] << std::endl;
+    IN_DTYPE accs[1];
+    float tmp[1];
+    for(int32_t m = 0 ; m < RD_; ++m)
+    for(int32_t p = 0 ; p < RH_; ++p)
+    for(int32_t q = 0; q < RW_; ++q)
+    for(int32_t n = 0; n < B_; ++n)
+    for(int32_t k = 0; k < NF_ ; ++k)
+    {
+      for(int32_t i = 0; i < 1; ++i)
+        accs[i] = 0;
+      int32_t mm = m*stride_d_ - pad_d_;
+      int32_t pp = p*stride_h_ - pad_h_;
+      int32_t qq = q*stride_w_ - pad_w_;
+      for(int32_t kk = 0; kk < 1; ++kk)
+      for(int32_t c = 0; c < NC_; ++c)
+      for(int32_t t = 0; t < T_; ++t)
+      for(int32_t r = 0; r < R_; ++r)
+      for(int32_t s = 0; s < S_; ++s){
+        int32_t d = mm + t;
+        int32_t h = pp + r;
+        int32_t w = qq + s;
+        bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_);
+        IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0;
+        IN_DTYPE b;
+        if(ty_==FPROP)
+          b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+        else
+          b = B[idx(c, t, s, r, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+        accs[kk] = std::fma(a, b, accs[kk]);
+      }
+      for(int32_t kk = 0; kk < 1; ++kk){
+        tmp[kk] = accs[kk];
+      }
+      C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0];
+    }
+  }
+
 private:
   // image size
-  int B_;
-  int NC_;
-  int D_;
-  int H_;
-  int W_;
+  int32_t B_;
+  int32_t NC_;
+  int32_t D_;
+  int32_t H_;
+  int32_t W_;
   // filter size
-  int T_;
-  int R_;
-  int S_;
-  int NF_;
+  int32_t T_;
+  int32_t R_;
+  int32_t S_;
+  int32_t NF_;
   // activation size
-  int RD_;
-  int RH_;
-  int RW_;
+  int32_t RD_;
+  int32_t RH_;
+  int32_t RW_;
   // upsampling
-  int upsample_d_;
-  int upsample_h_;
-  int upsample_w_;
+  int32_t upsample_d_;
+  int32_t upsample_h_;
+  int32_t upsample_w_;
   // padding
-  int pad_d_;
-  int pad_h_;
-  int pad_w_;
+  int32_t pad_d_;
+  int32_t pad_h_;
+  int32_t pad_w_;
   // striding
-  int stride_d_;
-  int stride_h_;
-  int stride_w_;
+  int32_t stride_d_;
+  int32_t stride_h_;
+  int32_t stride_w_;
   // equivalent matmul
-  int M_;
-  int N_;
-  int K_;
+  int32_t M_;
+  int32_t N_;
+  int32_t K_;
   // helpers
-  int Fs_;
-  int TK_;
-  int Luts_;
+  int32_t Fs_;
+  int32_t TK_;
+  int32_t Luts_;
   // memory strides for data
   int32_t stride_a_w_;
   int32_t stride_a_h_;
@@ -251,7 +373,9 @@ private:
   int32_t stride_c_m_;
   int32_t stride_c_k_;
   int32_t stride_c_n_;
-
+  // type
+  type ty_;
+  bool is_bprop_;
 };
 
 }