diff --git a/examples/python/pytorch/run.py b/examples/python/pytorch/run.py
index 488f547f8..59f70d6c5 100644
--- a/examples/python/pytorch/run.py
+++ b/examples/python/pytorch/run.py
@@ -78,14 +78,10 @@ class NetReference(nn.Module):
         self.fc2 = nn.Linear(500, 10)
 
     def forward(self, x):
-        #x = x.permute(1, 2, 3, 0).contiguous()
         x = self.conv1(x)
-        #x = x.permute(3, 0, 1, 2).contiguous()
         x = self.bn1(x)
         x = F.relu(x)
-        #x = x.permute(1, 2, 3, 0).contiguous()
         x = self.conv2(x)
-        #x = x.permute(3, 0, 1, 2).contiguous()
         x = self.bn2(x)
         x = F.relu(x)
         x = x.view(-1, 32*7*7)
diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp
index 1da8f3fbd..d650ca9e6 100644
--- a/examples/python/pytorch/shift.cpp
+++ b/examples/python/pytorch/shift.cpp
@@ -9,12 +9,34 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
+void extract_shapes(const torch::Tensor &x,
+                    int64_t &C, int64_t &H, int64_t &W, int64_t &B,
+                    triton::dnn::shift::layout_t layout) {
+  if(layout == triton::dnn::shift::CHWN){
+    C = x.size(0);
+    H = x.size(1);
+    W = x.size(2);
+    B = x.size(3);
+  }
+  else if(layout == triton::dnn::shift::NCHW){
+    B = x.size(0);
+    C = x.size(1);
+    H = x.size(2);
+    W = x.size(3);
+  }
+  else{
+    throw std::runtime_error("unsupported layout");
+  }
+}
+
+static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW;
+
 torch::Tensor shift_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t F,
     int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w,
-    triton::dnn::shift::type ty,
+    triton::dnn::shift::type ty, triton::dnn::shift::layout_t layout,
     torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
     bool autotune = false
     ) {
@@ -28,7 +50,7 @@ torch::Tensor shift_common(
   triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                            stride_h, stride_w,
                            shift_h, shift_w, "fp32", "fp32",
-                           ty, has_bias);
+                           ty, has_bias, layout);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
@@ -56,10 +78,8 @@ torch::Tensor shift_y(
   CHECK_INPUT(x);
   CHECK_INPUT(w);
   // shapes for a
-  int64_t Ca = x.size(0);
-  int64_t H = x.size(1);
-  int64_t W = x.size(2);
-  int64_t B = x.size(3);
+  int64_t Ca, H, W, B;
+  extract_shapes(x, Ca, H, W, B, layout);
   // shapes for b
   int64_t Cb = w.size(0);
   int64_t F = w.size(1);
@@ -68,7 +88,7 @@ torch::Tensor shift_y(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::FPROP, x, w, bias);
+                      triton::dnn::shift::FPROP, layout, x, w, bias);
 }
 
 torch::Tensor shift_dx(
@@ -81,10 +101,8 @@ torch::Tensor shift_dx(
   CHECK_INPUT(dy);
   CHECK_INPUT(w);
   // shapes for a
-  int64_t Ca = dy.size(0);
-  int64_t H = dy.size(1);
-  int64_t W = dy.size(2);
-  int64_t B = dy.size(3);
+  int64_t Ca, H, W, B;
+  extract_shapes(dy, Ca, H, W, B, layout);
   H *= stride_h;
   W *= stride_w;
   // shapes for b
@@ -98,7 +116,7 @@ torch::Tensor shift_dx(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::BPROP, dy, w, bias);
+                      triton::dnn::shift::BPROP, layout, dy, w, bias);
 }
 
 torch::Tensor shift_dw(
@@ -111,15 +129,11 @@ torch::Tensor shift_dw(
   CHECK_INPUT(dy);
   CHECK_INPUT(x);
   // shapes for a
-  int64_t F = dy.size(0);
-  int64_t Ha = dy.size(1);
-  int64_t Wa = dy.size(2);
-  int64_t Ba = dy.size(3);
+  int64_t F, Ha, Wa, Ba;
+  extract_shapes(dy, F, Ha, Wa, Ba, layout);
   // shapes for b
-  int64_t C = x.size(0);
-  int64_t Hb = x.size(1);
-  int64_t Wb = x.size(2);
-  int64_t Bb = x.size(3);
+  int64_t C, Hb, Wb, Bb;
+  extract_shapes(x, C, Hb, Wb, Bb, layout);
   // check
   AT_CHECK(Ha*stride_h == Hb, "operands must have the same image height");
   AT_CHECK(Wa*stride_w == Wb, "operands must have the same image width");
@@ -130,7 +144,7 @@ torch::Tensor shift_dw(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::WGRAD, dy, x, bias);
+                      triton::dnn::shift::WGRAD, layout, dy, x, bias);
 }
 
 static auto registry =
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 375c45227..9de71d8a4 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -62,7 +62,7 @@ def run_shift():
     R, S, F = 3, 3, 32
     stride_h, stride_w = 2, 2
     np.random.seed(2)
-    a = tf.placeholder(tf.float32, shape=[C, H, W, B])
+    a = tf.placeholder(tf.float32, shape=[B, C, H, W])
     b = tf.placeholder(tf.float32, shape=[C, F])
     hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
     hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
@@ -70,13 +70,13 @@ def run_shift():
     #hshift_w = np.zeros(C, dtype=np.int32)
     c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
     # feed values
-    ha = np.random.rand(C, H, W, B)
+    ha = np.random.rand(B, C, H, W)
     hb = np.random.rand(C, F)
-    #ha = np.ones((C, H, W, B), dtype=np.float32)
+    #ha = np.ones((B, C, H, W), dtype=np.float32)
     #hb = np.ones((C, F), dtype=np.float32)
     sess = tf.InteractiveSession()
     # test
-    grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B),
+    grads = tf.test.compute_gradient([a, b], [(B, C, H, W), (C, F)], c, (B, F, H//stride_h, W//stride_w),
                                      extra_feed_dict = {a: ha, b: hb})
     dw_t, dw_n = grads[1]
     dx_t, dx_n = grads[0]
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index bde4d1b5e..d9014795e 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -22,7 +22,7 @@ using GPUDevice = Eigen::GpuDevice;
 template<triton::dnn::shift::type OP>
 class ShiftConvOp : public OpKernel {
 public:
-  explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {
     context->GetAttr("shift_h", &h_shift_h_);
     context->GetAttr("shift_w", &h_shift_w_);
     context->GetAttr("stride_h", &stride_h_);
@@ -31,20 +31,32 @@ public:
     S_ = 3;
   }
 
+  void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) {
+    if(layout_ == triton::dnn::shift::CHWN){
+      C = x.dim_size(0);
+      H = x.dim_size(1);
+      W = x.dim_size(2);
+      B = x.dim_size(3);
+    }
+    else if(layout_ == triton::dnn::shift::NCHW){
+      B = x.dim_size(0);
+      C = x.dim_size(1);
+      H = x.dim_size(2);
+      W = x.dim_size(3);
+    }
+    else{
+      throw std::runtime_error("unsupported layout");
+    }
+  }
+
   void FillShapes(OpKernelContext* context,
                   int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
                   const Tensor& tf_a, const Tensor& tf_b) {
     if(OP == triton::dnn::shift::WGRAD) {
-      // shapes for a
-      F = tf_a.dim_size(0);
-      int64_t Ha = tf_a.dim_size(1);
-      int64_t Wa = tf_a.dim_size(2);
-      int64_t Ba = tf_a.dim_size(3);
-      // shapes for b
-      C = tf_b.dim_size(0);
-      int64_t Hb = tf_b.dim_size(1);
-      int64_t Wb = tf_b.dim_size(2);
-      int64_t Bb = tf_b.dim_size(3);
+      int64_t Ha, Wa, Ba;
+      int64_t Hb, Wb, Bb;
+      ExtractShapes(tf_a, F, Ha, Wa, Ba);
+      ExtractShapes(tf_b, C, Hb, Wb, Bb);
       OP_REQUIRES(context, Ha*stride_h_ == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
       OP_REQUIRES(context, Wa*stride_w_ == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
       OP_REQUIRES(context, Ba == Bb, tensorflow::errors::InvalidArgument("operands must have the same batch size"));
@@ -54,10 +66,8 @@ public:
     }
     else {
       // shapes for a
-      int64_t Ca = tf_a.dim_size(0);
-      H = tf_a.dim_size(1);
-      W = tf_a.dim_size(2);
-      B = tf_a.dim_size(3);
+      int64_t Ca;
+      ExtractShapes(tf_a, Ca, H, W, B);
       if(OP == triton::dnn::shift::BPROP){
         H *= stride_h_;
         W *= stride_w_;
@@ -96,7 +106,7 @@ public:
     triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
                              stride_h_, stride_w_,
                              shift_h_data, shift_w_data,
-                             "fp32", "fp32", OP, has_bias);
+                             "fp32", "fp32", OP, has_bias, layout_);
 
     // shapes for c
     std::vector<int64_t> c_shapes;
@@ -122,6 +132,7 @@ private:
   int stride_w_;
   int R_;
   int S_;
+  triton::dnn::shift::layout_t layout_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index b85ffe299..1731508d0 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -65,7 +65,7 @@ public:
         int stride_h, int stride_w,
         const int32_t* shift_h, const int32_t* shift_w,
         std::string a_ty = "fp32", std::string b_ty = "fp32",
-        type ty = FPROP, bool bias = false);
+        type ty = FPROP, bool bias = false, layout_t layout = CHWN);
 
   // look-up table
   void build_delta_a();
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 0bdcb49e2..da3b5877d 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -13,7 +13,8 @@ shift::shift(int B, int C,
              int stride_h, int stride_w,
              const int32_t *shift_h, const int32_t *shift_w,
              std::string a_ty, std::string b_ty,
-             type ty, bool bias)
+             type ty, bool bias,
+             layout_t layout)
   : base("shift"),
     B_(B), C_(C),
     AD_(D), AH_(H), AW_(W),
@@ -23,7 +24,7 @@ shift::shift(int B, int C,
     shift_h_(shift_h), shift_w_(shift_w),
     a_ty_(a_ty), b_ty_(b_ty),
     ty_(ty), bias_(bias),
-    layout_(CHWN){
+    layout_(layout){
   // max number of channels
   TK_ = 16;
   MAX_C_ = 8192 + TK_;
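Note on the change: both helpers added by this patch (extract_shapes in the PyTorch binding, ExtractShapes in the TensorFlow op) are the same dimension-order lookup, mapping the four raw tensor sizes to (C, H, W, B) according to the memory layout, and shift.h defaults the new constructor parameter to CHWN so existing call sites that omit it keep the old behavior. Below is a minimal, standalone C++ sketch of that lookup for reference; the Layout enum and the dims array are illustrative stand-ins for triton::dnn::shift::layout_t and the framework tensor APIs, not part of any public interface.

#include <array>
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Stand-in for triton::dnn::shift::layout_t (two layouts, as in the patch).
enum class Layout { CHWN, NCHW };

// Same mapping as extract_shapes()/ExtractShapes(): interpret the four raw
// dimension sizes as (C, H, W, B) according to the tensor's memory layout.
void extract_shapes(const std::array<int64_t, 4>& dims, Layout layout,
                    int64_t& C, int64_t& H, int64_t& W, int64_t& B) {
  switch (layout) {
    case Layout::CHWN:  // dims = {C, H, W, B}: channels outermost, batch innermost
      C = dims[0]; H = dims[1]; W = dims[2]; B = dims[3];
      break;
    case Layout::NCHW:  // dims = {B, C, H, W}: batch outermost
      B = dims[0]; C = dims[1]; H = dims[2]; W = dims[3];
      break;
    default:
      throw std::runtime_error("unsupported layout");
  }
}

int main() {
  // The same physical sizes mean different things under each layout.
  std::array<int64_t, 4> dims{8, 32, 28, 28};
  int64_t C, H, W, B;
  extract_shapes(dims, Layout::NCHW, C, H, W, B);
  std::cout << "NCHW: B=" << B << " C=" << C << " H=" << H << " W=" << W << "\n";
  extract_shapes(dims, Layout::CHWN, C, H, W, B);
  std::cout << "CHWN: C=" << C << " H=" << H << " W=" << W << " B=" << B << "\n";
}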