From 066ae338f1370ff72e6e394709054fa69a82dece Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 9 Jul 2019 14:08:51 -0700
Subject: [PATCH] [dnn/shift]: added stride to shift

---
 examples/python/tensorflow/run.py    | 34 +++++++++----------
 examples/python/tensorflow/shift.cpp | 30 +++++++++++++----
 include/triton/dnn/shift.h           |  5 +++
 lib/dnn/shift.cpp                    | 49 ++++++++++++++++------------
 4 files changed, 74 insertions(+), 44 deletions(-)

diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index a0f107ea4..1fcf68587 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -49,35 +49,35 @@ def run_conv():
 def blocksparse_matmul_grad(op, dy):
   shift_h = op.get_attr('shift_h')
   shift_w = op.get_attr('shift_w')
+  stride_h = op.get_attr('stride_h')
+  stride_w = op.get_attr('stride_w')
   x = op.inputs[0]
   w = op.inputs[1]
-  dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
-  dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
+  dx = module.shift_conv_dx(dy, w, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
+  dw = module.shift_conv_dw(dy, x, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
   return (dx, dw)

 def run_shift():
-  B, C, H, W = 16, 1024, 8, 8
-  R, S, F = 3, 3, 1024
+  B, C, H, W = 16, 16, 4, 4
+  R, S, F = 3, 3, 4
+  stride_h, stride_w = 2, 2
   np.random.seed(2)
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
   b = tf.placeholder(tf.float32, shape=[C, F])
   hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
   hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
-  #hshift_h = np.ones(C, dtype=np.int32)
-  #hshift_w = np.ones(C, dtype=np.int32)
-  c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
-  # Reference
+  c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
+  # feed values
   ha = np.random.rand(C, H, W, B)
   hb = np.random.rand(C, F)
-  #ha = np.ones((C, H, W, B), dtype=np.int32)
-  #hb = np.ones((C, F), dtype=np.int32)
   sess = tf.InteractiveSession()
-  #grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
-  #                                 extra_feed_dict = {a: ha, b: hb})
-  #dw_t, dw_n = grads[1]
-  #dx_t, dx_n = grads[0]
-  #print(np.max(np.abs(dw_t - dw_n)))
-  #print(np.max(np.abs(dx_t - dx_n)))
+  # test
+  grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B),
+                                   extra_feed_dict = {a: ha, b: hb})
+  dw_t, dw_n = grads[1]
+  dx_t, dx_n = grads[0]
+  print(np.max(np.abs(dw_t - dw_n)))
+  print(np.max(np.abs(dx_t - dx_n)))
   # Run
   sess.run(tf.global_variables_initializer())
   result = sess.run([c], feed_dict = {a: ha,
@@ -127,4 +127,4 @@ def run_batchnorm():
   print(np.max(np.abs(dg_t - dg_n)))
   print(np.max(np.abs(db_t - db_n)))

-run_batchnorm()
+run_shift()
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 6e9abec55..0ccd06d1f 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -34,6 +36,8 @@ public:
   explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
     context->GetAttr("shift_h", &h_shift_h_);
     context->GetAttr("shift_w", &h_shift_w_);
+    context->GetAttr("stride_h", &stride_h_);
+    context->GetAttr("stride_w", &stride_w_);
     R_ = 3;
     S_ = 3;
   }
@@ -52,12 +54,12 @@ public:
       int64_t Hb = tf_b.dim_size(1);
       int64_t Wb = tf_b.dim_size(2);
      int64_t Bb = tf_b.dim_size(3);
-      OP_REQUIRES(context, Ha == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
-      OP_REQUIRES(context, Wa == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
+      OP_REQUIRES(context, Ha*stride_h_ == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
+      OP_REQUIRES(context, Wa*stride_w_ == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
       OP_REQUIRES(context, Ba == Bb, tensorflow::errors::InvalidArgument("operands must have the same batch size"));
-      H = Ha;
-      W = Wa;
-      B = Ba;
+      H = Hb;
+      W = Wb;
+      B = Bb;
     }
     else {
       // shapes for a
@@ -65,6 +67,10 @@ public:
       H = tf_a.dim_size(1);
       W = tf_a.dim_size(2);
       B = tf_a.dim_size(3);
+      if(OP == triton::dnn::shift::BPROP){
+        H *= stride_h_;
+        W *= stride_w_;
+      }
       // shapes for b
       int64_t Cb = tf_b.dim_size(0);
       F = tf_b.dim_size(1);
@@ -104,7 +110,9 @@ public:
     if(m_config.find(key) == m_config.end())
       shift = m_config.emplace(key, new triton::dnn::shift(
                                        B, C, D, H, W, T, R_, S_, F,
-                                       shift_h, shift_w, "fp32", "fp32", OP, has_bias))
+                                       stride_h_, stride_w_,
+                                       shift_h, shift_w,
+                                       "fp32", "fp32", OP, has_bias))
                     .first->second.get();
     else
       shift = m_config.at(key).get();
@@ -125,7 +133,7 @@ public:
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
     // get JIT
     triton::jit* jit;
-    bool autotune = true;
+    bool autotune = false;
     if(m_jit.find(key) == m_jit.end()) {
       jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
       std::ostringstream oss;
@@ -171,6 +179,8 @@ private:
   Tensor h_shift_h_;
   Tensor h_shift_w_;
+  int stride_h_;
+  int stride_w_;
   int R_;
   int S_;
 };

@@ -181,6 +191,8 @@ REGISTER_OP("ShiftConv")
     .Input("b: float32")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
+    .Attr("stride_h: int")
+    .Attr("stride_w: int")
     .Output("c: float32");

 REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
@@ -189,6 +201,8 @@ REGISTER_OP("ShiftConvDx")
     .Input("b: float32")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
+    .Attr("stride_h: int")
+    .Attr("stride_w: int")
     .Output("c: float32");

 REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
@@ -197,5 +211,7 @@ REGISTER_OP("ShiftConvDw")
     .Input("b: float32")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
+    .Attr("stride_h: int")
+    .Attr("stride_w: int")
     .Output("c: float32");

diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index 3c4b53037..e9bd921df 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -52,6 +52,7 @@ public:
   shift(int B, int NC,
         int D, int H, int W,
         int T, int R, int S, int NF,
+        int stride_h, int stride_w,
         const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
         std::string a_ty = "fp32", std::string b_ty = "fp32",
         type ty = FPROP, bool bias = false);
@@ -133,6 +134,10 @@ private:
   std::vector<int32_t> shapes_a_;
   std::vector<int32_t> shapes_b_;
   std::vector<int32_t> shapes_c_;
+  // strides
+  int32_t stride_d_;
+  int32_t stride_h_;
+  int32_t stride_w_;
   // memory strides
   std::vector<int32_t> ld_a_;
   std::vector<int32_t> ld_b_;
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 66f3e0c35..99078e0cd 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -17,6 +17,7 @@ shift::shift(int B, int C,
              int D, int H, int W,
              int T, int R, int S,
              int F,
+             int stride_h, int stride_w,
              const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
              std::string a_ty, std::string b_ty,
              type ty, bool bias)
@@ -24,6 +25,7 @@ shift::shift(int B, int C,
    AD_(D), AH_(H), AW_(W),
    BD_(T), BH_(R), BW_(S),
    F_(F),
+   stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
    shift_h_(shift_h), shift_w_(shift_w),
    a_ty_(a_ty), b_ty_(b_ty),
    ty_(ty), bias_(bias) {
@@ -33,17 +35,21 @@ shift::shift(int B, int C,
   // transpose
   AT_ = false;
   BT_ = true;
+  // activation sizes
+  CD_ = AD_ / stride_d_;
+  CH_ = AH_ / stride_h_;
+  CW_ = AW_ / stride_w_;
   // equivalent matmul
-  M_ = B_*AH_*AW_;
+  M_ = B_*CH_*CW_;
   N_ = F_;
   K_ = C_;
   // shapes
   // input layout: C, H, W, B
   // filter layout: C, F
   // output layout: F, H, W, B
-  shapes_a_ = {C, H, W, B};
+  shapes_a_ = {C, AH_, AW_, B};
   shapes_b_ = {C, F};
-  shapes_c_ = {F, H, W, B};
+  shapes_c_ = {F, CH_, CW_, B};
   if(ty_ == WGRAD){
     shapes_b_.swap(shapes_c_);
     shapes_a_.swap(shapes_b_);
@@ -51,14 +57,14 @@ shift::shift(int B, int C,
     BT_ = false;
     M_ = F_;
     N_ = C_;
-    K_ = B_*AH_*AW_;
+    K_ = B_*CH_*CW_;
   }
   if(ty_ == BPROP){
     shapes_a_.swap(shapes_c_);
     AT_ = false;
     BT_ = false;
     K_ = F_;
-    M_ = B_*AH_*AW_;
+    M_ = B_*CH_*CW_;
     N_ = C_;
   }
   // memory strides
@@ -133,13 +139,15 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(3, M_);
   kernel->setArg(4, N_);
   kernel->setArg(5, K_);
-  kernel->setArg(6, lda);
-  kernel->setArg(7, ldb);
-  kernel->setArg(8, B_);
-  kernel->setArg(9, AH_);
-  kernel->setArg(10, AW_);
-  kernel->setArg(11, BH_);
-  kernel->setArg(12, BW_);
+  kernel->setArg(6, stride_h_);
+  kernel->setArg(7, stride_w_);
+  kernel->setArg(8, lda);
+  kernel->setArg(9, ldb);
+  kernel->setArg(10, B_);
+  kernel->setArg(11, AH_);
+  kernel->setArg(12, AW_);
+  kernel->setArg(13, BH_);
+  kernel->setArg(14, BW_);
   std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
   if(ty_ == BPROP)
     ((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
@@ -188,6 +196,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
            restrict read_only align(16) )" << b_ty_ << R"( *b,
            fp32 *c,
            int32 M, int32 N, int32 K,
+           int32 stride_h, int32 stride_w,
            int32 lda, int32 ldb,
            int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
   int32 rxa[TM] = get_global_range[TM](0);
@@ -200,9 +209,9 @@
 if(ty_ == FPROP){
   os << R"(
   int32 rawhc[TM] = rxa / ABS;
-  int32 raw[TM] = rawhc % AW;
+  int32 raw[TM] = (rawhc % AW)*stride_w;
   int32 rahc[TM] = rawhc / AW;
-  int32 rah[TM] = rahc % AH;
+  int32 rah[TM] = (rahc % AH)*stride_h;
   __constant__ int32* pd[TK] = delta + rka;
   multiple_of(4) int32 d[TK] = *pd;
   int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
@@ -227,9 +236,9 @@ if(ty_ == WGRAD){
 if(ty_ == WGRAD){
   os << R"(
   int32 rbwhc[TK] = rkb / ABS;
-  int32 rbw[TK] = rbwhc % AW;
+  int32 rbw[TK] = (rbwhc % AW)*stride_w;
   int32 rbhc[TK] = rbwhc / AW;
-  int32 rbh[TK] = rbhc % AH;
+  int32 rbh[TK] = (rbhc % AH)*stride_h;
   int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
   int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
   int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
@@ -266,9 +275,9 @@ if(ty_ == WGRAD){
     pb = pb + TK)" << ldb0 << R"(;
     rkb = rkb + TK;
     rbwhc = rkb / ABS;
-    rbw = rbwhc % AW;
+    rbw = (rbwhc % AW)*stride_w;
     rbhc = rbwhc / AW;
-    rbh = rbhc % AH;
+    rbh = (rbhc % AH)*stride_h;
     interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));
     interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));
     interior = interiorh[:, newaxis] && interiorw[:, newaxis];
@@ -292,9 +301,9 @@ else{
 if(ty_ == BPROP){
   os << R"(
   int32 rcwhc[TM] = rxc / ABS;
-  int32 rcw[TM] = rcwhc % AW;
+  int32 rcw[TM] = (rcwhc % AW)*stride_w;
   int32 rchc[TM] = rcwhc / AW;
-  int32 rch[TM] = rchc % AH;
+  int32 rch[TM] = (rchc % AH)*stride_h;
   int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
   int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
   int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
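
Note on the shape arithmetic above: with the sizes used in run_shift() (B, C, H, W = 16, 16, 4, 4, F = 4, stride 2x2), the strided output shape is shapes_c_ = (F, H/stride_h, W/stride_w, B) = (4, 2, 2, 16), and the equivalent matmul dimensions become M = B*CH*CW = 64, N = F = 4, K = C = 16 for FPROP. A minimal sanity check of this bookkeeping, in plain Python mirroring the constructor's formulas:

    B, C, H, W = 16, 16, 4, 4
    F = 4
    stride_h, stride_w = 2, 2
    CH, CW = H // stride_h, W // stride_w   # activation sizes after striding
    assert (F, CH, CW, B) == (4, 2, 2, 16)  # shapes_c_ for FPROP
    M, N, K = B * CH * CW, F, C             # equivalent matmul (FPROP)
    assert (M, N, K) == (64, 4, 16)
    M, N, K = F, C, B * CH * CW             # WGRAD swaps the roles
    assert (M, N, K) == (4, 16, 64)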
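
For readers unfamiliar with the operation, the semantics the strided kernel computes can be sketched in a few lines of NumPy. This is a minimal reference under stated assumptions, not the kernel's exact behavior: shift_conv_ref is a name introduced here for illustration, the sign convention of the per-channel shifts is an assumption, and np.roll wraps around at the borders whereas the kernel's interior masks suppress out-of-range shifts, so values near the edges differ.

    import numpy as np

    def shift_conv_ref(a, w, shift_h, shift_w, stride_h, stride_w):
        # a: activations, layout (C, H, W, B); w: pointwise weights, layout (C, F)
        C, H, W, B = a.shape
        shifted = np.empty_like(a)
        for c in range(C):
            # per-channel spatial shift (wrap-around; the kernel masks instead)
            shifted[c] = np.roll(a[c], (-shift_h[c], -shift_w[c]), axis=(0, 1))
        # strided subsampling, then a 1x1 convolution, i.e. a matmul over channels
        shifted = shifted[:, ::stride_h, ::stride_w, :]
        return np.einsum('chwb,cf->fhwb', shifted, w)  # (F, H//stride_h, W//stride_w, B)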