[dnn/conv] some minor fixes

Philippe Tillet
2019-05-08 10:09:30 -04:00
parent 615569287e
commit 54f888a270
2 changed files with 247 additions and 167 deletions

View File

@@ -9,22 +9,19 @@
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
+triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
-int32_t NC = 64, T = 1, R = 3, S = 3;
+int32_t NC = 32, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
int32_t RD = (D*upsample_d - T + 1 + 2*pad_d + stride_d - 1)/stride_d;
int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
int32_t RW = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
-// equivalent matmul dimensions
-int32_t M = B*RD*RH*RW;
-int32_t N = NF;
-int32_t K = NC*T*R*S;
+triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty);
// convolution configuration
std::vector<float> hc(B*RH*RW*NF);
std::vector<float> rc(B*RH*RW*NF);
@@ -36,7 +33,8 @@ int main() {
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
-hc[i] = 0;
+hc[i] = (float)rand()/RAND_MAX;
+rc = hc;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
@@ -45,80 +43,38 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
-// memory strides for data
-int32_t stride_i_w = 1;
-int32_t stride_i_h = W*stride_i_w;
-int32_t stride_i_d = H*stride_i_h;
-int32_t stride_i_c = D*stride_i_d;
-int32_t stride_i_n = NC*stride_i_c;
-// memory stride for activations
-int32_t stride_o_q = 1;
-int32_t stride_o_p = RW*stride_o_q;
-int32_t stride_o_m = RH*stride_o_p;
-int32_t stride_o_k = RD*stride_o_m;
-int32_t stride_o_n = NF*stride_o_k;
-// look-up table
-triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
std::vector<int> h_delta, h_masks;
-configuration.build_lut(h_delta, h_masks);
+configuration.build_deltas(h_delta);
+configuration.build_masks(h_masks);
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
-// initialize constant memory
+unsigned nthreads = info.num_threads;
+std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
stream->synchronize();
-// launch info
-unsigned nthreads = info.num_threads;
-std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
-// set arguments
-kernel->setArg(0, da);
-kernel->setArg(1, db);
-kernel->setArg(2, dc);
-kernel->setArg(3, M);
-kernel->setArg(4, N);
-kernel->setArg(5, K);
-kernel->setArg(6, B);
-kernel->setArg(7, H);
-kernel->setArg(8, W);
-kernel->setArg(9, NF);
-kernel->setArg(10, RH);
-kernel->setArg(11, RW);
-kernel->setArg(12, NC);
-kernel->setArg(13, R);
-kernel->setArg(14, S);
-kernel->setArg(15, stride_i_n);
-kernel->setArg(16, stride_i_c);
-kernel->setArg(17, stride_i_h);
-kernel->setArg(18, stride_i_w);
-kernel->setArg(19, stride_o_n);
-kernel->setArg(20, stride_o_k);
-kernel->setArg(21, stride_o_p);
-kernel->setArg(22, stride_o_q);
-kernel->setArg(23, pad_h);
-kernel->setArg(24, pad_w);
// dry run
+configuration.set_arg(kernel, da, db, dc);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, *context->device());
-return 2.*M*N*K / ts * 1e-3;
+return configuration.get_nflops() / ts * 1e-3;
};
-std::string src = triton::dnn::conv::src();
+std::string src = configuration.src();
// jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), triton::dnn::conv::default_params());
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
-cpp_conv_nchw(NC, B, NF, D, H, W, T, R, S, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, RD, RH, RW, rc, ha, hb);
-for(size_t i = 0; i < M*N; i++)
+configuration.cpu_ref(rc.data(), ha.data(), hb.data());
+for(size_t i = 0; i < hc.size(); i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
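
With these changes, everything the example used to compute inline (data strides, the equivalent matmul shape, kernel arguments, grid size) is delegated to the triton::dnn::conv object. The following is a minimal sketch of the resulting driver-side flow, assuming the jit, stream and device buffers da/db/dc from the example above; the constant-memory writes for the delta/masks look-up tables are omitted:

// sketch of the post-commit call sequence, not a verbatim excerpt
triton::dnn::conv configuration(B, NC, H, W, R, S, NF,
                                1, 1, pad_h, pad_w,
                                triton::dnn::conv::BPROP);
std::vector<int> h_delta, h_masks;
configuration.build_deltas(h_delta);    // pointer-increment look-up table
configuration.build_masks(h_masks);     // padding-validity bitmasks
std::string src = configuration.src();  // source now depends on the conv type
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::array<size_t, 3> grid = configuration.get_grid(info.global_range_size[0],
                                                    info.global_range_size[1]);
configuration.set_arg(kernel, da, db, dc); // swaps a/c internally for BPROP
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});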

View File

@@ -1,5 +1,7 @@
#include <string>
#include <vector>
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton{
namespace dnn{
@@ -15,10 +17,13 @@ public:
conv(int B, int NC, int H, int W, int R, int S, int NF,
int upsample_h, int upsample_w,
-int pad_h, int pad_w)
+int pad_h, int pad_w,
+type ty = FPROP)
: B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
-pad_d_(0), pad_h_(pad_h), pad_w_(pad_w)
+stride_d_(1), stride_h_(1), stride_w_(1),
+pad_d_(0), pad_h_(pad_h), pad_w_(pad_w),
+ty_(ty)
{
RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
@@ -26,9 +31,6 @@ public:
M_ = B*RD_*RH_*RW_;
N_ = NF;
K_ = NC*T_*R_*S_;
-Fs_ = T_*R_*S_;
-TK_ = 8;
-Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
// memory strides for data
stride_a_w_ = 1;
stride_a_h_ = W_*stride_a_w_;
@@ -41,88 +43,160 @@ public:
stride_c_m_ = RH_*stride_c_p_;
stride_c_k_ = RD_*stride_c_m_;
stride_c_n_ = NF_*stride_c_k_;
+// swap a and c for bprop
+if(ty_ == BPROP){
+std::swap(stride_a_n_, stride_c_n_);
+std::swap(stride_a_c_, stride_c_k_);
+std::swap(stride_a_h_, stride_c_p_);
+std::swap(stride_a_w_, stride_c_q_);
+std::swap(D_, RD_);
+std::swap(H_, RH_);
+std::swap(W_, RW_);
+pad_d_ = (RD_ - D_ + T_ - 1) / 2;
+pad_h_ = (RH_ - H_ + R_ - 1) / 2;
+pad_w_ = (RW_ - W_ + S_ - 1) / 2;
+}
+// look-up table info
+Fs_ = T_*R_*S_;
+TK_ = 8;
+Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
}
-void build_lut(std::vector<int>& delta, std::vector<int>& masks) {
-delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
-masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
-/* unpack index wrt filters */
-auto unpack = [&](int32_t trs){
-int32_t tr = trs / S_;
-int32_t s = trs - tr*S_;
-int32_t t = tr / R_;
-int32_t r = tr - t*R_;
-return std::make_tuple(t, r, s);
-};
-/* increments */
-for(size_t i = 0; i < Luts_; ++i)
-delta[i] = (((i + TK_) % Luts_) - i);
-/* deltas */
-size_t Ds0 = Luts_;
-size_t Ds1 = upsample_w_;
-size_t Ds2 = upsample_h_;
-size_t Ds3 = upsample_d_;
-for(size_t pd = 0; pd < Ds3; ++pd)
-for(size_t ph = 0; ph < Ds2; ++ph)
-for(size_t pw = 0; pw < Ds1; ++pw){
-int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
-// cumulative increments
-for(size_t i = 0; i < Ds0; ++i){
-int32_t ctrs = i;
-int32_t c = ctrs / Fs_;
-int32_t t, r, s;
-std::tie(t, r, s) = unpack(ctrs % Fs_);
-// next indices
-int32_t nextctrs = ctrs + TK_;
-int32_t nextc = nextctrs / Fs_;
-int32_t nextt, nextr, nexts;
-std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
-// diffs
-int32_t cdiff = nextc - c;
-int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
-int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
-int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
-// delta pointers
-deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
-}
+void build_deltas(std::vector<int>& deltas){
+deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
+auto unpack = [&](int32_t trs){
+int32_t tr = trs / S_;
+int32_t s = trs - tr*S_;
+int32_t t = tr / R_;
+int32_t r = tr - t*R_;
+return std::make_tuple(t, r, s);
+};
+for(size_t i = 0; i < Luts_; ++i)
+deltas[i] = (((i + TK_) % Luts_) - i);
+size_t Ds0 = Luts_;
+size_t Ds1 = upsample_w_;
+size_t Ds2 = upsample_h_;
+size_t Ds3 = upsample_d_;
+for(size_t pd = 0; pd < Ds3; ++pd)
+for(size_t ph = 0; ph < Ds2; ++ph)
+for(size_t pw = 0; pw < Ds1; ++pw){
+int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
+// cumulative increments
+for(size_t i = 0; i < Ds0; ++i){
+int32_t ctrs = i;
+int32_t c = ctrs / Fs_;
+int32_t t, r, s;
+std::tie(t, r, s) = unpack(ctrs % Fs_);
+// next indices
+int32_t nextctrs = ctrs + TK_;
+int32_t nextc = nextctrs / Fs_;
+int32_t nextt, nextr, nexts;
+std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
+// diffs
+int32_t cdiff = nextc - c;
+int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
+int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
+int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
+// delta pointers
+deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
+}
+}
+}
-/* Masks */
-size_t Ms0 = Luts_;
-size_t Ms1 = 2*pad_w_ + 1;
-size_t Ms2 = 2*pad_h_ + 1;
-size_t Ms3 = 2*pad_d_ + 1;
-for(size_t pd = 0; pd < Ms3; ++pd)
-for(size_t ph = 0; ph < Ms2; ++ph)
-for(size_t pw = 0; pw < Ms1; ++pw){
-int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
-for(size_t i = 0; i < Ms0; ++i){
-int32_t t, r, s;
-int32_t mask = 0x0;
-for(size_t j = 0; j < TK_; ++j){
-std::tie(t, r, s) = unpack((i + j) % Fs_);
-bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
-bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
-bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
-mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
-}
-masks_ptr[i] = mask;
-}
+void build_masks(std::vector<int>& masks){
+masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
+auto unpack = [&](int32_t trs){
+int32_t tr = trs / S_;
+int32_t s = trs - tr*S_;
+int32_t t = tr / R_;
+int32_t r = tr - t*R_;
+return std::make_tuple(t, r, s);
+};
+size_t Ms0 = Luts_;
+size_t Ms1 = 2*pad_w_ + 1;
+size_t Ms2 = 2*pad_h_ + 1;
+size_t Ms3 = 2*pad_d_ + 1;
+for(size_t pd = 0; pd < Ms3; ++pd)
+for(size_t ph = 0; ph < Ms2; ++ph)
+for(size_t pw = 0; pw < Ms1; ++pw){
+int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
+for(size_t i = 0; i < Ms0; ++i){
+int32_t t, r, s;
+int32_t mask = 0x0;
+for(size_t j = 0; j < TK_; ++j){
+std::tie(t, r, s) = unpack((i + j) % Fs_);
+bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
+bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
+bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
+mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
+}
+masks_ptr[i] = mask;
+}
-for(size_t i = 0; i < Luts_; ++i)
-masks[i] = 0x0;
-}
+for(size_t i = 0; i < Luts_; ++i)
+masks[i] = 0x0;
+}
-static std::vector<unsigned> default_params() {
-return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
+std::array<size_t, 3> get_grid(size_t TM, size_t TN){
+return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
+}
+size_t get_nflops(){
+return 2.*M_*N_*K_;
+}
+void set_arg(driver::kernel *kernel,
+driver::buffer *a, driver::buffer *b, driver::buffer *c)
+{
+if(ty_ == BPROP)
+std::swap(a, c);
+kernel->setArg(0, a);
+kernel->setArg(1, b);
+kernel->setArg(2, c);
+kernel->setArg(3, M_);
+kernel->setArg(4, N_);
+kernel->setArg(5, K_);
+kernel->setArg(6, B_);
+kernel->setArg(7, H_);
+kernel->setArg(8, W_);
+kernel->setArg(9, NF_);
+kernel->setArg(10, RH_);
+kernel->setArg(11, RW_);
+kernel->setArg(12, NC_);
+kernel->setArg(13, R_);
+kernel->setArg(14, S_);
+kernel->setArg(15, stride_a_n_);
+kernel->setArg(16, stride_a_c_);
+kernel->setArg(17, stride_a_h_);
+kernel->setArg(18, stride_a_w_);
+kernel->setArg(19, stride_c_n_);
+kernel->setArg(20, stride_c_k_);
+kernel->setArg(21, stride_c_p_);
+kernel->setArg(22, stride_c_q_);
+kernel->setArg(23, pad_h_);
+kernel->setArg(24, pad_w_);
+}
+std::vector<unsigned> default_params() {
+if(ty_ == FPROP)
+return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
+else
+return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2};
+}
-static std::string src(type ty = FPROP) {
+std::string src() {
+std::string bs0 = "TN", bs1 = "TK";
+std::string ldb0 = "*NF", ldb1 = "";
+std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
+std::string b = "b";
+if(ty_ == BPROP){
+std::swap(bs0, bs1);
+std::swap(ldb0, ldb1);
+std::swap(bcb0, bcb1);
+b = "trans(b)";
+}
std::string res =
R"(
const tunable int32 TM = {16, 32, 64};
@@ -158,7 +232,7 @@ public:
int32 rar[TK] = racr % R;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
-fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];
+fp32* pb[)" + bs0 + ", " + bs1 + R"(] = b + rb1)" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
__constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + R*S + rka;
int32 d[TK] = *pd;
@@ -172,10 +246,10 @@ public:
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
-fp32 b[TN, TK] = *pb;
+fp32 b[)" + bs0 + ", " + bs1 + R"(] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
-pb = pb + TK*NF;
+pb = pb + TK)" + ldb0 + R"(;
pa = pa + d[newaxis, :];
b = *pb;
pd = pd + incd;
@@ -203,42 +277,90 @@ public:
return res;
}
+template<class IN_DTYPE, class OUT_DTYPE>
+void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
+{
+auto idx = [&](int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
+int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
+{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; };
+if(ty_==BPROP){
+std::swap(A, C);
+}
+std::cout << A[0] << std::endl;
+IN_DTYPE accs[1];
+float tmp[1];
+for(int32_t m = 0 ; m < RD_; ++m)
+for(int32_t p = 0 ; p < RH_; ++p)
+for(int32_t q = 0; q < RW_; ++q)
+for(int32_t n = 0; n < B_; ++n)
+for(int32_t k = 0; k < NF_ ; ++k)
+{
+for(int32_t i = 0; i < 1; ++i)
+accs[i] = 0;
+int32_t mm = m*stride_d_ - pad_d_;
+int32_t pp = p*stride_h_ - pad_h_;
+int32_t qq = q*stride_w_ - pad_w_;
+for(int32_t kk = 0; kk < 1; ++kk)
+for(int32_t c = 0; c < NC_; ++c)
+for(int32_t t = 0; t < T_; ++t)
+for(int32_t r = 0; r < R_; ++r)
+for(int32_t s = 0; s < S_; ++s){
+int32_t d = mm + t;
+int32_t h = pp + r;
+int32_t w = qq + s;
+bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_);
+IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0;
+IN_DTYPE b;
+if(ty_==FPROP)
+b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+else
+b = B[idx(c, t, s, r, k*1 + kk, NC_, T_, R_, S_, NF_*1)];
+accs[kk] = std::fma(a, b, accs[kk]);
+}
+for(int32_t kk = 0; kk < 1; ++kk){
+tmp[kk] = accs[kk];
+}
+C[idx(n, k, m, p, q, B_, NF_, RD_, RH_, RW_)] = tmp[0];
+}
+}
private:
// image size
-int B_;
-int NC_;
-int D_;
-int H_;
-int W_;
+int32_t B_;
+int32_t NC_;
+int32_t D_;
+int32_t H_;
+int32_t W_;
// filter size
-int T_;
-int R_;
-int S_;
-int NF_;
+int32_t T_;
+int32_t R_;
+int32_t S_;
+int32_t NF_;
// activation size
-int RD_;
-int RH_;
-int RW_;
+int32_t RD_;
+int32_t RH_;
+int32_t RW_;
// upsampling
-int upsample_d_;
-int upsample_h_;
-int upsample_w_;
+int32_t upsample_d_;
+int32_t upsample_h_;
+int32_t upsample_w_;
// padding
-int pad_d_;
-int pad_h_;
-int pad_w_;
+int32_t pad_d_;
+int32_t pad_h_;
+int32_t pad_w_;
// striding
-int stride_d_;
-int stride_h_;
-int stride_w_;
+int32_t stride_d_;
+int32_t stride_h_;
+int32_t stride_w_;
// equivalent matmul
-int M_;
-int N_;
-int K_;
+int32_t M_;
+int32_t N_;
+int32_t K_;
// helpers
-int Fs_;
-int TK_;
-int Luts_;
+int32_t Fs_;
+int32_t TK_;
+int32_t Luts_;
// memory strides for data
int32_t stride_a_w_;
int32_t stride_a_h_;
@@ -251,7 +373,9 @@ private:
int32_t stride_c_m_;
int32_t stride_c_k_;
int32_t stride_c_n_;
+// type
+type ty_;
bool is_bprop_;
};
}
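
For reference, the two new look-up-table builders split what build_lut used to do: build_deltas stores, for each position i along the flattened reduction axis (c, t, r, s), the address increment needed to jump TK reduction steps ahead, and build_masks stores a TK-wide bitmask recording which of those steps fall inside the unpadded image. A self-contained sketch of the delta computation for the 2D, no-upsampling case follows; R, S, TK and the strides are illustrative values, not taken from this commit:

#include <cstdint>
#include <vector>

int main() {
  const int32_t R = 3, S = 3, TK = 8;           // filter size, reduction tile
  const int32_t Fs = R*S;                       // filter taps per input channel
  const int32_t Luts = (TK + Fs - 1) / Fs * Fs; // table size, rounded up to a multiple of Fs
  const int32_t stride_w = 1, stride_h = 240, stride_c = 24*240; // hypothetical NCHW strides
  std::vector<int32_t> deltas(Luts);
  for (int32_t i = 0; i < Luts; ++i) {
    // unpack (c, r, s) at reduction step i and at step i + TK
    int32_t c  = i / Fs,  r  = (i % Fs) / S,  s  = (i % Fs) % S;
    int32_t ni = i + TK;
    int32_t nc = ni / Fs, nr = (ni % Fs) / S, ns = (ni % Fs) % S;
    // address difference between the two steps, in elements
    deltas[i] = (nc - c)*stride_c + (nr - r)*stride_h + (ns - s)*stride_w;
  }
  return 0;
}

The masks table uses the same indexing: for each of the TK steps starting at i, one bit records whether the (t, r, s) tap stays inside the image for a given padding offset, which is what lets the generated kernel predicate its loads with a single bitwise test (the checka lines in the source string above).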