[dnn/conv] added triton-c code for wgrad
This commit is contained in:
@@ -10,13 +10,13 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
int32_t NC = 32, T = 1, R = 3, S = 3;
|
||||
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty);
|
||||
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
|
||||
// convolution configuration
|
||||
std::vector<float> hc(configuration.c_size());
|
||||
std::vector<float> rc(configuration.c_size());
|
||||
@@ -40,8 +40,10 @@ int main() {
|
||||
stream->synchronize();
|
||||
// look-up table
|
||||
std::vector<int> h_delta, h_masks;
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
configuration.build_deltas(h_delta);
|
||||
configuration.build_masks(h_masks);
|
||||
}
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
@@ -49,10 +51,12 @@ int main() {
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
triton::driver::buffer* delta = jit.get_buffer("delta");
|
||||
triton::driver::buffer* masks = jit.get_buffer("masks");
|
||||
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
|
||||
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
|
||||
}
|
||||
stream->synchronize();
|
||||
configuration.set_arg(kernel, da, db, dc);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
@@ -69,11 +73,11 @@ int main() {
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
stream->read(dc, true, 0, hc);
|
||||
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
// std::cout << c[0] << std::endl;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
for(size_t i = 0; i < hc.size(); i++){
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
@@ -1,5 +1,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
@@ -15,74 +17,91 @@ public:
|
||||
};
|
||||
|
||||
|
||||
conv(int B, int NC, int H, int W, int R, int S, int NF,
|
||||
int upsample_h, int upsample_w,
|
||||
int pad_h, int pad_w,
|
||||
conv(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
int upsample_d, int upsample_h, int upsample_w,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
type ty = FPROP)
|
||||
: B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
|
||||
upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
|
||||
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
|
||||
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
|
||||
stride_d_(1), stride_h_(1), stride_w_(1),
|
||||
pad_d_(0), pad_h_(pad_h), pad_w_(pad_w),
|
||||
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
|
||||
ty_(ty)
|
||||
{
|
||||
RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||
RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||
RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||
// memory strides for data
|
||||
stride_a_w_ = 1;
|
||||
stride_a_h_ = W_*stride_a_w_;
|
||||
stride_a_d_ = H_*stride_a_h_;
|
||||
stride_a_c_ = D_*stride_a_d_;
|
||||
stride_a_n_ = NC_*stride_a_c_;
|
||||
// memory stride for activations
|
||||
stride_c_q_ = 1;
|
||||
stride_c_p_ = RW_*stride_c_q_;
|
||||
stride_c_m_ = RH_*stride_c_p_;
|
||||
stride_c_k_ = RD_*stride_c_m_;
|
||||
stride_c_n_ = NF_*stride_c_k_;
|
||||
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||
// shapes
|
||||
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
|
||||
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
|
||||
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
|
||||
// swap a and c for bprop
|
||||
if(ty_ == BPROP){
|
||||
std::swap(stride_a_n_, stride_c_n_);
|
||||
std::swap(stride_a_c_, stride_c_k_);
|
||||
std::swap(stride_a_h_, stride_c_p_);
|
||||
std::swap(stride_a_w_, stride_c_q_);
|
||||
std::swap(D_, RD_);
|
||||
std::swap(H_, RH_);
|
||||
std::swap(W_, RW_);
|
||||
std::swap(NF_, NC_);
|
||||
pad_d_ = (RD_ - D_ + T_ - 1) / 2;
|
||||
pad_h_ = (RH_ - H_ + R_ - 1) / 2;
|
||||
pad_w_ = (RW_ - W_ + S_ - 1) / 2;
|
||||
pad_d_ = (CD_ - AD_ + BD_ - 1) / 2;
|
||||
pad_h_ = (CH_ - AH_ + BH_ - 1) / 2;
|
||||
pad_w_ = (CW_ - AW_ + BW_ - 1) / 2;
|
||||
shapes_a_.swap(shapes_c_);
|
||||
}
|
||||
// swap b and c for wgrad
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
}
|
||||
// leading dimensions
|
||||
auto set_ld = [](const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[4] = 1;
|
||||
ld[3] = shapes[4]*ld[4];
|
||||
ld[2] = shapes[3]*ld[3];
|
||||
ld[1] = shapes[2]*ld[2];
|
||||
ld[0] = shapes[1]*ld[1];
|
||||
};
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
set_ld(shapes_b_, ld_b_);
|
||||
set_ld(shapes_c_, ld_c_);
|
||||
// equivalent matmul
|
||||
M_ = B_*RD_*RH_*RW_;
|
||||
N_ = NF_;
|
||||
K_ = NC_*T_*R_*S_;
|
||||
if(ty_ == WGRAD){
|
||||
M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3];
|
||||
N_ = shapes_c_[4];
|
||||
K_ = shapes_b_[0]*shapes_b_[2]*shapes_b_[3]*shapes_b_[4];
|
||||
}
|
||||
else{
|
||||
M_ = shapes_c_[0]*shapes_c_[2]*shapes_c_[3]*shapes_c_[4];
|
||||
N_ = shapes_c_[1];
|
||||
K_ = shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
}
|
||||
// look-up table info
|
||||
Fs_ = T_*R_*S_;
|
||||
Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
TK_ = 8;
|
||||
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
|
||||
}
|
||||
|
||||
size_t a_size() {
|
||||
return B_*NC_*D_*H_*W_;
|
||||
return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
size_t b_size() {
|
||||
return NC_*NF_*T_*R_*S_;
|
||||
return std::accumulate(shapes_b_.begin(), shapes_b_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
size_t c_size() {
|
||||
return B_*NF_*RD_*RH_*RW_;
|
||||
return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
void build_deltas(std::vector<int>& deltas){
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
auto unpack = [&](int32_t trs){
|
||||
int32_t tr = trs / S_;
|
||||
int32_t s = trs - tr*S_;
|
||||
int32_t t = tr / R_;
|
||||
int32_t r = tr - t*R_;
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs - tr*BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr - t*BH_;
|
||||
return std::make_tuple(t, r, s);
|
||||
};
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
@@ -112,18 +131,20 @@ public:
|
||||
int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
|
||||
// delta pointers
|
||||
deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
|
||||
deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void build_masks(std::vector<int>& masks){
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
auto unpack = [&](int32_t trs){
|
||||
int32_t tr = trs / S_;
|
||||
int32_t s = trs - tr*S_;
|
||||
int32_t t = tr / R_;
|
||||
int32_t r = tr - t*R_;
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs - tr*BW_;
|
||||
int32_t t = tr / BH_;
|
||||
int32_t r = tr - t*BH_;
|
||||
return std::make_tuple(t, r, s);
|
||||
};
|
||||
size_t Ms0 = Luts_;
|
||||
@@ -139,9 +160,9 @@ public:
|
||||
int32_t mask = 0x0;
|
||||
for(size_t j = 0; j < TK_; ++j){
|
||||
std::tie(t, r, s) = unpack((i + j) % Fs_);
|
||||
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
|
||||
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
|
||||
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
|
||||
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (BD_ + pad_d_);
|
||||
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (BH_ + pad_h_);
|
||||
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (BW_ + pad_w_);
|
||||
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
|
||||
}
|
||||
masks_ptr[i] = mask;
|
||||
@@ -168,46 +189,40 @@ public:
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, B_);
|
||||
kernel->setArg(7, H_);
|
||||
kernel->setArg(8, W_);
|
||||
kernel->setArg(9, NF_);
|
||||
kernel->setArg(10, RH_);
|
||||
kernel->setArg(11, RW_);
|
||||
kernel->setArg(12, NC_);
|
||||
kernel->setArg(13, R_);
|
||||
kernel->setArg(14, S_);
|
||||
kernel->setArg(15, stride_a_n_);
|
||||
kernel->setArg(16, stride_a_c_);
|
||||
kernel->setArg(17, stride_a_h_);
|
||||
kernel->setArg(18, stride_a_w_);
|
||||
kernel->setArg(19, stride_c_n_);
|
||||
kernel->setArg(20, stride_c_k_);
|
||||
kernel->setArg(21, stride_c_p_);
|
||||
kernel->setArg(22, stride_c_q_);
|
||||
kernel->setArg(23, pad_h_);
|
||||
kernel->setArg(24, pad_w_);
|
||||
kernel->setArg(6, AH_);
|
||||
kernel->setArg(7, AW_);
|
||||
kernel->setArg(8, BH_);
|
||||
kernel->setArg(9, BW_);
|
||||
kernel->setArg(10, CH_);
|
||||
kernel->setArg(11, CW_);
|
||||
kernel->setArg(12, ld_a_[0]);
|
||||
kernel->setArg(13, ld_a_[1]);
|
||||
kernel->setArg(14, ld_a_[2]);
|
||||
kernel->setArg(15, ld_a_[3]);
|
||||
kernel->setArg(16, ld_a_[4]);
|
||||
kernel->setArg(17, ld_b_[0]);
|
||||
kernel->setArg(18, ld_b_[1]);
|
||||
kernel->setArg(19, ld_b_[2]);
|
||||
kernel->setArg(20, ld_b_[3]);
|
||||
kernel->setArg(21, ld_b_[4]);
|
||||
kernel->setArg(22, ld_c_[0]);
|
||||
kernel->setArg(23, ld_c_[1]);
|
||||
kernel->setArg(24, ld_c_[2]);
|
||||
kernel->setArg(25, ld_c_[3]);
|
||||
kernel->setArg(26, ld_c_[4]);
|
||||
kernel->setArg(27, pad_h_);
|
||||
kernel->setArg(28, pad_w_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> default_params() {
|
||||
// if(ty_ == FPROP)
|
||||
if(ty_ == FPROP || ty_ == BPROP)
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
|
||||
// else
|
||||
// return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2};
|
||||
else
|
||||
return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8};
|
||||
}
|
||||
|
||||
|
||||
std::string src() {
|
||||
std::string bs0 = "TN", bs1 = "TK";
|
||||
std::string ldb0 = "*NF", ldb1 = "";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string b = "b";
|
||||
if(ty_ == BPROP){
|
||||
std::swap(bs0, bs1);
|
||||
std::swap(ldb0, ldb1);
|
||||
std::swap(bcb0, bcb1);
|
||||
b = "trans(b)";
|
||||
}
|
||||
std::string xprop() {
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
@@ -221,36 +236,37 @@ public:
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 B, int32 H, int32 W,
|
||||
int32 NF, int32 RH, int32 RW,
|
||||
int32 NC, int32 R, int32 S,
|
||||
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rb1[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 rabh[TM] = rxa / RW;
|
||||
int32 raw[TM] = rxa % RW - pad_w;
|
||||
int32 rab[TM] = rabh / RH;
|
||||
int32 rah[TM] = rabh % RH - pad_h;
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW - pad_w;
|
||||
int32 rab[TM] = rabh / CH;
|
||||
int32 rah[TM] = rabh % CH - pad_h;
|
||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 racr[TK] = rka / S;
|
||||
int32 ras[TK] = rka % S;
|
||||
int32 rac[TK] = racr / R;
|
||||
int32 rar[TK] = racr % R;
|
||||
int32 racr[TK] = rka / BW;
|
||||
int32 ras[TK] = rka % BW;
|
||||
int32 rac[TK] = racr / BH;
|
||||
int32 rar[TK] = racr % BH;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rb1[newaxis, :]*ldb_s + rb0[:, newaxis];
|
||||
__constant__ int32* pincd[TK] = delta + rka;
|
||||
__constant__ int32* pd[TK] = delta + R*S + rka;
|
||||
__constant__ int32* pd[TK] = delta + BH*BW + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + R - H, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + S - W, 0);
|
||||
__constant__ int32* pm[TM] = masks + R*S + maskw*R*S + maskh*R*S*(2*pad_w + 1);
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
__constant__ int32* pm[TM] = masks + BH*BW + maskw*BH*BW + maskh*BH*BW*(2*pad_w + 1);
|
||||
__constant__ int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
@@ -260,7 +276,7 @@ public:
|
||||
fp32 b[TN, TK] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*NF;
|
||||
pb = pb + TK*ldb_s;
|
||||
pa = pa + d[newaxis, :];
|
||||
b = *pb;
|
||||
pd = pd + incd;
|
||||
@@ -276,8 +292,8 @@ public:
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (RH*RW);
|
||||
int32 rcpq[TM] = rxc % (RH*RW);
|
||||
int32 rcn[TM] = rxc / (CH*CW);
|
||||
int32 rcpq[TM] = rxc % (CH*CW);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
@@ -288,62 +304,169 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
// where A is N,C,AH,AW
|
||||
// B is N,K,BH,BW
|
||||
// C is C,CH,CW,K
|
||||
std::string wgrad() {
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AH, int32 AW,
|
||||
int32 CH, int32 CW,
|
||||
int32 BH, int32 BW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
|
||||
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
|
||||
int32 pad_h, int32 pad_w){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rk[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 racr[TM] = rxa / CW;
|
||||
int32 raw_base[TM] = rxa % CW - pad_w;
|
||||
int32 rac[TM] = racr / CH;
|
||||
int32 rah_base[TM] = racr % CH - pad_h;
|
||||
fp32* pa_base[TM, TK] = a + rac[:, newaxis]*lda_c;
|
||||
fp32* pb_base[TN, TK] = b + ryb[:, newaxis]*ldb_k;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 rknp[TK] = rk / BW;
|
||||
int32 rkq[TK] = rk % BW;
|
||||
int32 rkn[TK] = rknp / BH;
|
||||
int32 rkp[TK] = rknp % BH;
|
||||
int32 rah[TM, TK] = rah_base[:, newaxis] + rkp[newaxis, :];
|
||||
int32 raw[TM, TK] = raw_base[:, newaxis] + rkq[newaxis, :];
|
||||
int1 checka[TM, TK] = (rah >= 0) && (rah < AH) && (raw >= 0) && (raw < AW);
|
||||
fp32* pa[TM, TK] = pa_base + rah*lda_h + raw*lda_w + rkn*lda_n;
|
||||
fp32* pb[TN, TK] = pb_base + rkp*ldb_p + rkq*ldb_q + rkn*ldb_n;
|
||||
fp32 A[TM, TK] = checka ? *pa : 0;
|
||||
fp32 B[TN, TK] = *pb;
|
||||
C = dot(A, trans(B), C);
|
||||
rk = rk + TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
int32 rccr[TM] = rxc / CW;
|
||||
int32 rcs[TM] = rxa % CW;
|
||||
int32 rcc[TM] = racr / CH;
|
||||
int32 rcr[TM] = racr % CH;
|
||||
int32 rc0[TM] = rcc*ldc_c + rcr*ldc_r + rcs*ldc_s;
|
||||
fp32* pc[TM, TN] = c + rc0[:, newaxis] + ryc[newaxis, :]*ldc_k;
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string src() {
|
||||
if(ty_ == FPROP || ty_ == BPROP)
|
||||
return xprop();
|
||||
else
|
||||
return wgrad();
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
IN_DTYPE acc;
|
||||
for(int32_t n = 0; n < shapes_c_[0]; ++n)
|
||||
for(int32_t k = 0; k < shapes_c_[1] ; ++k)
|
||||
for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd)
|
||||
for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch)
|
||||
for(int32_t cw = 0; cw < shapes_c_[4]; ++cw)
|
||||
{
|
||||
acc = 0;
|
||||
int32_t d = cd*stride_d_ - pad_d_;
|
||||
int32_t h = ch*stride_h_ - pad_h_;
|
||||
int32_t w = cw*stride_w_ - pad_w_;
|
||||
for(int32_t c = 0; c < shapes_b_[0]; ++c)
|
||||
for(int32_t bd = 0; bd < shapes_b_[1]; ++bd)
|
||||
for(int32_t bh = 0; bh < shapes_b_[2]; ++bh)
|
||||
for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){
|
||||
int32_t ad = d + bd;
|
||||
int32_t ah = h + bh;
|
||||
int32_t aw = w + bw;
|
||||
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
|
||||
ah >= 0 && ah < shapes_a_[3] &&
|
||||
aw >= 0 && aw < shapes_a_[4]);
|
||||
IN_DTYPE a = 0;
|
||||
if(in_bounds)
|
||||
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
|
||||
IN_DTYPE b = B[c*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + k*ld_b_[4]];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
C[n*ld_c_[0] + k*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
IN_DTYPE acc;
|
||||
for(int32_t c = 0 ; c < shapes_c_[0]; ++c)
|
||||
for(int32_t cd = 0; cd < shapes_c_[1]; ++cd)
|
||||
for(int32_t ch = 0; ch < shapes_c_[2]; ++ch)
|
||||
for(int32_t cw = 0; cw < shapes_c_[3]; ++cw)
|
||||
for(int32_t k = 0 ; k < shapes_c_[4]; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
int32_t d = cd*stride_d_ - pad_d_;
|
||||
int32_t h = ch*stride_h_ - pad_h_;
|
||||
int32_t w = cw*stride_w_ - pad_w_;
|
||||
for(int32_t n = 0; n < shapes_b_[0]; ++n)
|
||||
for(int32_t bd = 0; bd < shapes_b_[2]; ++bd)
|
||||
for(int32_t bh = 0; bh < shapes_b_[3]; ++bh)
|
||||
for(int32_t bw = 0; bw < shapes_b_[4]; ++bw){
|
||||
int32_t ad = d + bd;
|
||||
int32_t ah = h + bh;
|
||||
int32_t aw = w + bw;
|
||||
bool in_bounds = (ad >= 0 && ad < shapes_a_[2] &&
|
||||
ah >= 0 && ah < shapes_a_[3] &&
|
||||
aw >= 0 && aw < shapes_a_[4]);
|
||||
IN_DTYPE a = 0;
|
||||
if(in_bounds)
|
||||
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
|
||||
IN_DTYPE b = B[n*ld_b_[0] + k*ld_b_[1] + bd*ld_b_[2] + bh*ld_b_[3] + bw*ld_b_[4]];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
C[c*ld_c_[0] + cd*ld_c_[1] + ch*ld_c_[2] + cw*ld_c_[3] + k*ld_c_[4]] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
|
||||
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
|
||||
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; };
|
||||
|
||||
IN_DTYPE accs[1];
|
||||
float tmp[1];
|
||||
for(int32_t m = 0 ; m < RD_; ++m)
|
||||
for(int32_t p = 0 ; p < RH_; ++p)
|
||||
for(int32_t q = 0; q < RW_; ++q)
|
||||
for(int32_t n = 0; n < B_; ++n)
|
||||
for(int32_t k = 0; k < NF_ ; ++k)
|
||||
{
|
||||
for(int32_t i = 0; i < 1; ++i)
|
||||
accs[i] = 0;
|
||||
int32_t mm = m*stride_d_ - pad_d_;
|
||||
int32_t pp = p*stride_h_ - pad_h_;
|
||||
int32_t qq = q*stride_w_ - pad_w_;
|
||||
for(int32_t kk = 0; kk < 1; ++kk)
|
||||
for(int32_t c = 0; c < NC_; ++c)
|
||||
for(int32_t t = 0; t < T_; ++t)
|
||||
for(int32_t r = 0; r < R_; ++r)
|
||||
for(int32_t s = 0; s < S_; ++s){
|
||||
int32_t d = mm + t;
|
||||
int32_t h = pp + r;
|
||||
int32_t w = qq + s;
|
||||
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_);
|
||||
IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0;
|
||||
IN_DTYPE b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
|
||||
accs[kk] = std::fma(a, b, accs[kk]);
|
||||
}
|
||||
for(int32_t kk = 0; kk < 1; ++kk){
|
||||
tmp[kk] = accs[kk];
|
||||
}
|
||||
C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0];
|
||||
}
|
||||
if(ty_ == FPROP || ty_ == BPROP)
|
||||
cpu_xprop(C, A, B);
|
||||
else
|
||||
cpu_wgrad(C, A, B);
|
||||
}
|
||||
|
||||
private:
|
||||
// image size
|
||||
int32_t B_;
|
||||
int32_t NB_;
|
||||
int32_t NC_;
|
||||
int32_t D_;
|
||||
int32_t H_;
|
||||
int32_t W_;
|
||||
int32_t AD_;
|
||||
int32_t AH_;
|
||||
int32_t AW_;
|
||||
// filter size
|
||||
int32_t T_;
|
||||
int32_t R_;
|
||||
int32_t S_;
|
||||
int32_t BD_;
|
||||
int32_t BH_;
|
||||
int32_t BW_;
|
||||
int32_t NF_;
|
||||
// activation size
|
||||
int32_t RD_;
|
||||
int32_t RH_;
|
||||
int32_t RW_;
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
int32_t CW_;
|
||||
// upsampling
|
||||
int32_t upsample_d_;
|
||||
int32_t upsample_h_;
|
||||
@@ -364,18 +487,15 @@ private:
|
||||
int32_t Fs_;
|
||||
int32_t TK_;
|
||||
int32_t Luts_;
|
||||
// memory strides for data
|
||||
int32_t stride_a_w_;
|
||||
int32_t stride_a_h_;
|
||||
int32_t stride_a_d_;
|
||||
int32_t stride_a_c_;
|
||||
int32_t stride_a_n_;
|
||||
// memory stride for activations
|
||||
int32_t stride_c_q_;
|
||||
int32_t stride_c_p_;
|
||||
int32_t stride_c_m_;
|
||||
int32_t stride_c_k_;
|
||||
int32_t stride_c_n_;
|
||||
// memory strides for A
|
||||
std::vector<int32_t> shapes_a_;
|
||||
std::vector<int32_t> ld_a_;
|
||||
// memory strides for B
|
||||
std::vector<int32_t> shapes_b_;
|
||||
std::vector<int32_t> ld_b_;
|
||||
// memory stride for C
|
||||
std::vector<int32_t> shapes_c_;
|
||||
std::vector<int32_t> ld_c_;
|
||||
// type
|
||||
type ty_;
|
||||
bool is_bprop_;
|
||||
|
@@ -812,7 +812,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
// res = builder.CreateCall(f_mul_add, {ConstantFP::get(a->getType(), 1), ConstantFP::get(b->getType(), 1), res});
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
|
Reference in New Issue
Block a user