more cleaning of conv
This commit is contained in:
@@ -25,6 +25,7 @@ int main() {
|
||||
int32_t M = B*RD*RH*RW;
|
||||
int32_t N = NF;
|
||||
int32_t K = NC*T*R*S;
|
||||
// convolution configuration
|
||||
std::vector<float> hc(B*RH*RW*NF);
|
||||
std::vector<float> rc(B*RH*RW*NF);
|
||||
std::vector<float> ha(B*NC*H*W);
|
||||
@@ -57,8 +58,9 @@ int main() {
|
||||
int32_t stride_o_k = RD*stride_o_m;
|
||||
int32_t stride_o_n = NF*stride_o_k;
|
||||
// look-up table
|
||||
triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
|
||||
std::vector<int> h_delta, h_masks;
|
||||
triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
|
||||
configuration.build_lut(h_delta, h_masks);
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
|
@@ -12,81 +12,106 @@ public:
|
||||
WGRAD
|
||||
};
|
||||
|
||||
static void build_lut(int TK,
|
||||
int stride_d, int stride_h, int stride_w, int stride_c,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int T, int R, int S,
|
||||
std::vector<int>& res, std::vector<int>& masks) {
|
||||
/* convolution parameters */
|
||||
int F = T * R * S;
|
||||
int Nlut = (TK + F - 1) / F * F;
|
||||
int upsample_w = 1;
|
||||
int upsample_h = 1;
|
||||
int upsample_d = 1;
|
||||
|
||||
conv(int B, int NC, int H, int W, int R, int S, int NF,
|
||||
int upsample_h, int upsample_w,
|
||||
int pad_h, int pad_w)
|
||||
: B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
|
||||
upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
|
||||
pad_d_(0), pad_h_(pad_h), pad_w_(pad_w)
|
||||
{
|
||||
RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||
RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||
RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||
M_ = B*RD_*RH_*RW_;
|
||||
N_ = NF;
|
||||
K_ = NC*T_*R_*S_;
|
||||
Fs_ = T_*R_*S_;
|
||||
TK_ = 8;
|
||||
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
|
||||
// memory strides for data
|
||||
stride_a_w_ = 1;
|
||||
stride_a_h_ = W_*stride_a_w_;
|
||||
stride_a_d_ = H_*stride_a_h_;
|
||||
stride_a_c_ = D_*stride_a_d_;
|
||||
stride_a_n_ = NC_*stride_a_c_;
|
||||
// memory stride for activations
|
||||
stride_c_q_ = 1;
|
||||
stride_c_p_ = RW_*stride_c_q_;
|
||||
stride_c_m_ = RH_*stride_c_p_;
|
||||
stride_c_k_ = RD_*stride_c_m_;
|
||||
stride_c_n_ = NF_*stride_c_k_;
|
||||
}
|
||||
|
||||
|
||||
void build_lut(std::vector<int>& delta, std::vector<int>& masks) {
|
||||
delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
|
||||
/* unpack index wrt filters */
|
||||
auto unpack = [&](int32_t trs){
|
||||
int32_t tr = trs / S;
|
||||
int32_t s = trs - tr*S;
|
||||
int32_t t = tr / R;
|
||||
int32_t r = tr - t*R;
|
||||
int32_t tr = trs / S_;
|
||||
int32_t s = trs - tr*S_;
|
||||
int32_t t = tr / R_;
|
||||
int32_t r = tr - t*R_;
|
||||
return std::make_tuple(t, r, s);
|
||||
};
|
||||
/* increments */
|
||||
for(size_t i = 0; i < Nlut; ++i)
|
||||
res[i] = (((i + TK) % Nlut) - i);
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
delta[i] = (((i + TK_) % Luts_) - i);
|
||||
/* deltas */
|
||||
size_t Ds0 = Nlut;
|
||||
size_t Ds1 = upsample_w;
|
||||
size_t Ds2 = upsample_h;
|
||||
size_t Ds3 = upsample_d;
|
||||
size_t Ds0 = Luts_;
|
||||
size_t Ds1 = upsample_w_;
|
||||
size_t Ds2 = upsample_h_;
|
||||
size_t Ds3 = upsample_d_;
|
||||
for(size_t pd = 0; pd < Ds3; ++pd)
|
||||
for(size_t ph = 0; ph < Ds2; ++ph)
|
||||
for(size_t pw = 0; pw < Ds1; ++pw){
|
||||
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i){
|
||||
int32_t ctrs = i;
|
||||
int32_t c = ctrs / F;
|
||||
int32_t c = ctrs / Fs_;
|
||||
int32_t t, r, s;
|
||||
std::tie(t, r, s) = unpack(ctrs % F);
|
||||
std::tie(t, r, s) = unpack(ctrs % Fs_);
|
||||
// next indices
|
||||
int32_t nextctrs = ctrs + TK;
|
||||
int32_t nextc = nextctrs / F;
|
||||
int32_t nextctrs = ctrs + TK_;
|
||||
int32_t nextc = nextctrs / Fs_;
|
||||
int32_t nextt, nextr, nexts;
|
||||
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
|
||||
std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
|
||||
// diffs
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
|
||||
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
|
||||
int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
|
||||
int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
|
||||
// delta pointers
|
||||
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
|
||||
deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
|
||||
}
|
||||
}
|
||||
|
||||
/* Masks */
|
||||
size_t Ms0 = Nlut;
|
||||
size_t Ms1 = 2*pad_w + 1;
|
||||
size_t Ms2 = 2*pad_h + 1;
|
||||
size_t Ms3 = 2*pad_d + 1;
|
||||
size_t Ms0 = Luts_;
|
||||
size_t Ms1 = 2*pad_w_ + 1;
|
||||
size_t Ms2 = 2*pad_h_ + 1;
|
||||
size_t Ms3 = 2*pad_d_ + 1;
|
||||
for(size_t pd = 0; pd < Ms3; ++pd)
|
||||
for(size_t ph = 0; ph < Ms2; ++ph)
|
||||
for(size_t pw = 0; pw < Ms1; ++pw){
|
||||
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
for(size_t i = 0; i < Ms0; ++i){
|
||||
int32_t t, r, s;
|
||||
int32_t mask = 0x0;
|
||||
for(size_t j = 0; j < TK; ++j){
|
||||
std::tie(t, r, s) = unpack((i + j) % F);
|
||||
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
|
||||
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
|
||||
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
|
||||
for(size_t j = 0; j < TK_; ++j){
|
||||
std::tie(t, r, s) = unpack((i + j) % Fs_);
|
||||
bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
|
||||
bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
|
||||
bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
|
||||
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
|
||||
}
|
||||
masks_ptr[i] = mask;
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < Nlut; ++i)
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
masks[i] = 0x0;
|
||||
|
||||
}
|
||||
@@ -95,20 +120,6 @@ public:
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
|
||||
}
|
||||
|
||||
static void init_cst(int stride_d, int stride_h, int stride_w, int stride_c,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int T, int R, int S,
|
||||
std::vector<int> &h_delta, std::vector<int> &h_masks) {
|
||||
int upsample_d = 1;
|
||||
int upsample_h = 1;
|
||||
int upsample_w = 1;
|
||||
int TK = 8;
|
||||
int F = T * R * S;
|
||||
int nlut = (TK + F - 1) / F * F;
|
||||
h_delta.resize(nlut + upsample_d*upsample_h*upsample_w*nlut);
|
||||
h_masks.resize(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
|
||||
build_lut(TK, stride_d, stride_h, stride_w, stride_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
|
||||
}
|
||||
|
||||
static std::string src(type ty = FPROP) {
|
||||
|
||||
@@ -191,6 +202,56 @@ public:
|
||||
})";
|
||||
return res;
|
||||
}
|
||||
|
||||
private:
|
||||
// image size
|
||||
int B_;
|
||||
int NC_;
|
||||
int D_;
|
||||
int H_;
|
||||
int W_;
|
||||
// filter size
|
||||
int T_;
|
||||
int R_;
|
||||
int S_;
|
||||
int NF_;
|
||||
// activation size
|
||||
int RD_;
|
||||
int RH_;
|
||||
int RW_;
|
||||
// upsampling
|
||||
int upsample_d_;
|
||||
int upsample_h_;
|
||||
int upsample_w_;
|
||||
// padding
|
||||
int pad_d_;
|
||||
int pad_h_;
|
||||
int pad_w_;
|
||||
// striding
|
||||
int stride_d_;
|
||||
int stride_h_;
|
||||
int stride_w_;
|
||||
// equivalent matmul
|
||||
int M_;
|
||||
int N_;
|
||||
int K_;
|
||||
// helpers
|
||||
int Fs_;
|
||||
int TK_;
|
||||
int Luts_;
|
||||
// memory strides for data
|
||||
int32_t stride_a_w_;
|
||||
int32_t stride_a_h_;
|
||||
int32_t stride_a_d_;
|
||||
int32_t stride_a_c_;
|
||||
int32_t stride_a_n_;
|
||||
// memory stride for activations
|
||||
int32_t stride_c_q_;
|
||||
int32_t stride_c_p_;
|
||||
int32_t stride_c_m_;
|
||||
int32_t stride_c_k_;
|
||||
int32_t stride_c_n_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user