From 0d8faa5b1e093b7059d60a3d8bd4d9da9fb429db Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 2 Jul 2019 21:38:10 -0700
Subject: [PATCH] fixup

---
 examples/python/tensorflow/shift.cpp | 73 +++++++++++++++++++---------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 2fe366de6..812912704 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -19,6 +19,7 @@ using namespace tensorflow;
 
 using GPUDevice = Eigen::GpuDevice;
 
+template<triton::dnn::shift::type OP>
 class ShiftConvOp : public OpKernel {
 public:
   explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -28,7 +29,40 @@ public:
     S_ = 3;
   }
 
-  void ComputeCommon(OpKernelContext* context){
+  void FillShapes(OpKernelContext* context,
+                  int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F,
+                  const Tensor& tf_a, const Tensor& tf_b) {
+    if(OP == triton::dnn::shift::WGRAD) {
+      // shapes for a
+      F = tf_a.dim_size(0);
+      int64_t Ha = tf_a.dim_size(1);
+      int64_t Wa = tf_a.dim_size(2);
+      int64_t Ba = tf_a.dim_size(3);
+      // shapes for b
+      C = tf_b.dim_size(0);
+      int64_t Hb = tf_b.dim_size(1);
+      int64_t Wb = tf_b.dim_size(2);
+      int64_t Bb = tf_b.dim_size(3);
+      OP_REQUIRES(context, Ha == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
+      OP_REQUIRES(context, Wa == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
+      OP_REQUIRES(context, Ba == Bb, tensorflow::errors::InvalidArgument("operands must have the same batch size"));
+      H = Ha;
+      W = Wa;
+      B = Ba;
+    }
+    else {
+      // shapes for a
+      int64_t Ca = tf_a.dim_size(0);
+      H = tf_a.dim_size(1);
+      W = tf_a.dim_size(2);
+      B = tf_a.dim_size(3);
+      // shapes for b
+      int64_t Cb = tf_b.dim_size(0);
+      F = tf_b.dim_size(1);
+      // checks
+      OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
+      C = Ca;
+    }
 
   }
 
@@ -41,23 +75,24 @@ public:
     // get inputs
     const Tensor& tf_a = context->input(0);
     const Tensor& tf_b = context->input(1);
-    // shapes for a
-    int64_t Ca = tf_a.dim_size(0);
-    int64_t H = tf_a.dim_size(1);
-    int64_t W = tf_a.dim_size(2);
-    int64_t B = tf_a.dim_size(3);
-    // shapes for b
-    int64_t Cb = tf_b.dim_size(0);
-    int64_t F = tf_b.dim_size(1);
-    // checks
-    OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
-    int64_t C = Ca;
+    // shapes
+    int64_t C, H, W, B, F;
+    FillShapes(context, C, H, W, B, F, tf_a, tf_b);
+    // shift configuration
+    int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
+    int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
+    std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
+    std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
+    triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", OP, false);
     // shapes for c
+    std::vector<int64> c_shapes;
+    for(int32_t x: shift.c_shapes())
+      c_shapes.push_back(x);
+    TensorShape out_shapes(c_shapes);
     Tensor* tf_c = nullptr;
-    TensorShape out_shape({Ca, H, W, B});
-    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tf_c));
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shapes, &tf_c));
     // return early if possible
-    if (out_shape.num_elements() == 0)
+    if (out_shapes.num_elements() == 0)
       return;
     // initialize default compute device
     triton::jit jit(ctx);
@@ -65,12 +100,6 @@ public:
     triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
     triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
-    // shift configuration
-    int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
-    int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
-    std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
-    std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
-    triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", triton::dnn::shift::FPROP, false);
     // benchmark a given matrix multiplication kernel
     auto benchmark = [&](triton::driver::kernel* kernel,
                          triton::jit::launch_information info) {
@@ -101,7 +130,7 @@ private:
   int S_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp);
+REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
 REGISTER_OP("ShiftConv")
   .Input("a: float32")
   .Input("b: float32")
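
Follow-up sketch: with ShiftConvOp now templated on triton::dnn::shift::type,
the backward kernels could be registered from the same class. A minimal
sketch, assuming the enum also defines a BPROP value and using hypothetical
op names "ShiftConvDx" and "ShiftConvDw" (only WGRAD and FPROP appear in the
patch itself):

  // Hypothetical registrations reusing the templated kernel. The op names
  // "ShiftConvDx"/"ShiftConvDw" and the BPROP enum value are assumptions,
  // not part of this patch.
  REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU),
                          ShiftConvOp<triton::dnn::shift::BPROP>);
  REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU),
                          ShiftConvOp<triton::dnn::shift::WGRAD>);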