diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp index 4d55babe7..df5b22803 100644 --- a/examples/cpp/conv.cpp +++ b/examples/cpp/conv.cpp @@ -16,24 +16,19 @@ int main() { int32_t D = 1, H = 24, W = 240; int32_t NC = 32, T = 1, R = 3, S = 3; int32_t pad_d = 0, pad_h = 1, pad_w = 1; - int32_t stride_d = 1, stride_h = 1, stride_w = 1; - int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; - int32_t RD = (D*upsample_d - T + 1 + 2*pad_d + stride_d - 1)/stride_d; - int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h; - int32_t RW = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w; triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, pad_h, pad_w, ty); // convolution configuration - std::vector hc(B*RH*RW*NF); - std::vector rc(B*RH*RW*NF); - std::vector ha(B*NC*H*W); - std::vector hb(NC*R*S*NF); + std::vector hc(configuration.c_size()); + std::vector rc(configuration.c_size()); + std::vector ha(configuration.a_size()); + std::vector hb(configuration.b_size()); srand(0); for(size_t i = 0; i < ha.size(); i++) ha[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hb.size(); i++) hb[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hc.size(); i++) - hc[i] = (float)rand()/RAND_MAX; + hc[i] = 0; rc = hc; triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); @@ -74,6 +69,7 @@ int main() { std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; stream->read(dc, true, 0, hc); configuration.cpu_ref(rc.data(), ha.data(), hb.data()); +// std::cout << c[0] << std::endl; for(size_t i = 0; i < hc.size(); i++) if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h index 11e222f3f..85bb1e038 100644 --- a/include/triton/dnn/conv.h +++ b/include/triton/dnn/conv.h @@ -28,9 +28,6 @@ public: RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_; RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_; RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_; - M_ = B*RD_*RH_*RW_; - N_ = NF; - K_ = NC*T_*R_*S_; // memory strides for data stride_a_w_ = 1; stride_a_h_ = W_*stride_a_w_; @@ -52,16 +49,33 @@ public: std::swap(D_, RD_); std::swap(H_, RH_); std::swap(W_, RW_); + std::swap(NF_, NC_); pad_d_ = (RD_ - D_ + T_ - 1) / 2; pad_h_ = (RH_ - H_ + R_ - 1) / 2; pad_w_ = (RW_ - W_ + S_ - 1) / 2; } + // equivalent matmul + M_ = B_*RD_*RH_*RW_; + N_ = NF_; + K_ = NC_*T_*R_*S_; // look-up table info Fs_ = T_*R_*S_; TK_ = 8; Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_; } + size_t a_size() { + return B_*NC_*D_*H_*W_; + } + + size_t b_size() { + return NC_*NF_*T_*R_*S_; + } + + size_t c_size() { + return B_*NF_*RD_*RH_*RW_; + } + void build_deltas(std::vector& deltas){ deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); auto unpack = [&](int32_t trs){ @@ -148,9 +162,6 @@ public: void set_arg(driver::kernel *kernel, driver::buffer *a, driver::buffer *b, driver::buffer *c) { - - if(ty_ == BPROP) - std::swap(a, c); kernel->setArg(0, a); kernel->setArg(1, b); kernel->setArg(2, c); @@ -179,10 +190,10 @@ public: } std::vector default_params() { - if(ty_ == FPROP) +// if(ty_ == FPROP) return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4}; - else - return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2}; +// else +// return {16, 2, 64, 16, 32, 16, 4, 2, 2, 4, 2, 8, 4, 2}; } @@ -232,7 +243,7 @@ public: int32 rar[TK] = racr % R; int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; - fp32* pb[)" + bs0 + ", " + bs1 + R"(] = b + rb1)" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(; + fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis]; __constant__ int32* pincd[TK] = delta + rka; __constant__ int32* pd[TK] = delta + R*S + rka; int32 d[TK] = *pd; @@ -246,10 +257,10 @@ public: int32 checka1[TK] = 1 << rka; int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0; fp32 a[TM, TK] = checka ? *pa : 0; - fp32 b[)" + bs0 + ", " + bs1 + R"(] = *pb; + fp32 b[TN, TK] = *pb; for(int32 k = K; k > 0; k = k - TK){ C = dot(a, trans(b), C); - pb = pb + TK)" + ldb0 + R"(; + pb = pb + TK*NF; pa = pa + d[newaxis, :]; b = *pb; pd = pd + incd; @@ -284,10 +295,6 @@ public: int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) { return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }; - if(ty_==BPROP){ - std::swap(A, C); - } - std::cout << A[0] << std::endl; IN_DTYPE accs[1]; float tmp[1]; for(int32_t m = 0 ; m < RD_; ++m) @@ -311,11 +318,7 @@ public: int32_t w = qq + s; bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D_ && h < H_ && w < W_); IN_DTYPE a = in_bounds?A[idx(n, c, d, h, w, B_, NC_, D_, H_, W_)]:0; - IN_DTYPE b; - if(ty_==FPROP) - b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)]; - else - b = B[idx(c, t, s, r, k*1 + kk, NC_, T_, R_, S_, NF_*1)]; + IN_DTYPE b = B[idx(c, t, r, s, k*1 + kk, NC_, T_, R_, S_, NF_*1)]; accs[kk] = std::fma(a, b, accs[kk]); } for(int32_t kk = 0; kk < 1; ++kk){