diff --git a/examples/python/tensorflow/CMakeLists.txt b/examples/python/tensorflow/CMakeLists.txt
index bfd54f6a6..08ba828e3 100644
--- a/examples/python/tensorflow/CMakeLists.txt
+++ b/examples/python/tensorflow/CMakeLists.txt
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TF_LIB})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-    add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp)
+    add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp shift.cpp)
     target_link_libraries(tf_blocksparse tensorflow_framework triton)
     configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                    ${CMAKE_CURRENT_BINARY_DIR}/run.py
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index dca626b1a..63acfdf2a 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -28,13 +28,13 @@ def run_dot():
     print("dif: %f" % np.max(dif))
 
 def run_conv():
-    BS, C, H, W = 16, 32, 32, 32
+    B, C, H, W = 16, 32, 32, 32
     R, S, NF = 3, 3, 32
-    a = tf.placeholder(tf.float32, shape=[BS, C, H, W])
+    a = tf.placeholder(tf.float32, shape=[B, C, H, W])
     b = tf.placeholder(tf.float32, shape=[C, R, S, NF])
     c = module.conv2d(a, b)
     # Reference
-    ha = np.random.rand(BS, C, H, W)
+    ha = np.random.rand(B, C, H, W)
     hb = np.random.rand(C, R, S, NF)
     # Run
     sess = tf.InteractiveSession()
@@ -42,4 +42,23 @@ def run_conv():
     result = sess.run([c], feed_dict = {a: ha,
                                         b: hb})[0]
 
-run_conv()
+def run_shift():
+    B, C, H, W = 16, 32, 32, 32
+    R, S, F = 3, 3, 32
+    a = tf.placeholder(tf.float32, shape=[C, H, W, B])
+    b = tf.placeholder(tf.float32, shape=[C, F])
+    shift_h = tf.zeros(C, tf.int32)
+    shift_w = tf.zeros(C, tf.int32)
+    hshift_h = np.zeros(C, np.int32)
+    hshift_w = np.zeros(C, np.int32)
+    c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
+    # Reference
+    ha = np.random.rand(C, H, W, B)
+    hb = np.random.rand(C, F)
+    # Run
+    sess = tf.InteractiveSession()
+    sess.run(tf.global_variables_initializer())
+    result = sess.run([c], feed_dict = {a: ha,
+                                        b: hb})[0]
+
+run_shift()
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
new file mode 100644
index 000000000..2fe366de6
--- /dev/null
+++ b/examples/python/tensorflow/shift.cpp
@@ -0,0 +1,111 @@
+#include <sstream>
+
+#include "triton/driver/buffer.h"
+#include "triton/driver/backend.h"
+#include "triton/driver/stream.h"
+#include "triton/runtime/jit.h"
+#include "triton/tools/bench.hpp"
+#include "triton/dnn/shift.h"
+
+#define EIGEN_USE_GPU
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+using namespace tensorflow;
+using GPUDevice = Eigen::GpuDevice;
+
+class ShiftConvOp : public OpKernel {
+public:
+  explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
+    context->GetAttr("shift_h", &h_shift_h_);
+    context->GetAttr("shift_w", &h_shift_w_);
+    R_ = 3;
+    S_ = 3;
+  }
+
+  void ComputeCommon(OpKernelContext* context){
+
+  }
+
+  void Compute(OpKernelContext* context){
+    // get device/stream
+    GPUDevice device = context->eigen_device<GPUDevice>();
+    triton::driver::cu_stream sstream(device.stream(), false);
+    triton::driver::context* ctx = sstream.context();
+    triton::driver::stream* stream = &sstream;
+    // get inputs
+    const Tensor& tf_a = context->input(0);
+    const Tensor& tf_b = context->input(1);
+    // shapes for a
+    int64_t Ca = tf_a.dim_size(0);
+    int64_t H  = tf_a.dim_size(1);
+    int64_t W  = tf_a.dim_size(2);
+    int64_t B  = tf_a.dim_size(3);
+    // shapes for b
+    int64_t Cb = tf_b.dim_size(0);
+    int64_t F  = tf_b.dim_size(1);
+    // checks
+    OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
+    int64_t C = Ca;
+    // shapes for c
+    Tensor* tf_c = nullptr;
+    TensorShape out_shape({Ca, H, W, B});
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tf_c));
+    // return early if possible
+    if (out_shape.num_elements() == 0)
+      return;
+    // initialize default compute device
+    triton::jit jit(ctx);
+    // matrix multiplication parameters
+    triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
+    triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
+    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
+    // shift configuration
+    int32_t* shift_h_data = h_shift_h_.flat<int32>().data();
+    int32_t* shift_w_data = h_shift_w_.flat<int32>().data();
+    std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
+    std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
+    triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", triton::dnn::shift::FPROP, false);
+    // benchmark a given matrix multiplication kernel
+    auto benchmark = [&](triton::driver::kernel* kernel,
+                         triton::jit::launch_information info) {
+      // launch info
+      unsigned TM = info.global_range_size[0];
+      unsigned TN = info.global_range_size[1];
+      unsigned nthreads = info.num_threads;
+      shift.init(stream, (triton::driver::cu_module*)kernel->module());
+      shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
+      stream->synchronize();
+      double ts = triton::tools::bench([&](){ shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
+                                       [&](){ stream->synchronize(); }, ctx->device());
+      return shift.get_nflops() / ts * 1e-3;
+    };
+
+    std::ostringstream oss;
+    shift.src(oss);
+    std::string src = oss.str();
+    triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
+  }
+
+private:
+  Tensor h_shift_h_;
+  Tensor h_shift_w_;
+//  triton::driver::buffer* d_shift_h_;
+//  triton::driver::buffer* d_shift_w_;
+  int R_;
+  int S_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp);
+REGISTER_OP("ShiftConv")
+    .Input("a: float32")
+    .Input("b: float32")
+    .Attr("shift_h: tensor")
+    .Attr("shift_w: tensor")
+    .Output("c: float32")
+;
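
Usage sketch: a minimal way to exercise the registered ShiftConv op from Python, mirroring run_shift() above. The load_op_library path (libtf_blocksparse.so, produced by the tf_blocksparse CMake target on Linux) and the zero shift vectors are assumptions; note that shift_h/shift_w are op attributes, so they must be baked in as constant tensor protos at graph-construction time rather than fed at run time.

import numpy as np
import tensorflow as tf

# Assumed library path; the CMake target tf_blocksparse builds this shared object.
module = tf.load_op_library('./libtf_blocksparse.so')

C, H, W, B, F = 32, 32, 32, 16, 32
# shift_h/shift_w are attrs of type "tensor", so pass constant tensor protos.
shift_h = tf.make_tensor_proto(np.zeros(C, np.int32))
shift_w = tf.make_tensor_proto(np.zeros(C, np.int32))

a = tf.placeholder(tf.float32, shape=[C, H, W, B])
b = tf.placeholder(tf.float32, shape=[C, F])
c = module.shift_conv(a, b, shift_h=shift_h, shift_w=shift_w)

with tf.Session() as sess:
    result = sess.run(c, feed_dict={a: np.random.rand(C, H, W, B),
                                    b: np.random.rand(C, F)})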