This commit is contained in:
Philippe Tillet
2019-07-03 19:25:16 -07:00
parent 0d8faa5b1e
commit 1d88f0a36b
7 changed files with 194 additions and 61 deletions

View File

@@ -1,7 +1,9 @@
import os
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
from time import time
data_files_path = tf.resource_loader.get_data_files_path()
library_dir = os.path.dirname(os.path.realpath(__file__))
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
@@ -42,23 +44,45 @@ def run_conv():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
@ops.RegisterGradient('ShiftConv')
def blocksparse_matmul_grad(op, dy):
shift_h = op.get_attr('shift_h')
shift_w = op.get_attr('shift_w')
x = op.inputs[0]
w = op.inputs[1]
dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
return (dx, dw)
def run_shift():
B, C, H, W = 16, 32, 32, 32
R, S, F = 3, 3, 32
B, C, H, W = 1, 16, 8, 8
R, S, F = 3, 3, 16
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
b = tf.placeholder(tf.float32, shape=[C, F])
shift_h = tf.zeros(C, tf.int32)
shift_w = tf.zeros(C, tf.int32)
hshift_h = np.zeros(C, np.int32)
hshift_w = np.zeros(C, np.int32)
#hshift_h = np.random.randint(-R//2, R//2 + 1, size=C, dtype=np.int32)
#hshift_w = np.random.randint(-S//2, R//2 + 1, size=C, dtype=np.int32)
hshift_h = 0*np.ones(C, dtype=np.int32)
hshift_w = 0*np.ones(C, dtype=np.int32)
c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
# Reference
ha = np.random.rand(C, H, W, B)
hb = np.random.rand(C, F)
# Run
ha = np.ones((C, H, W, B), dtype=np.int32)
hb = np.ones((C, F), dtype=np.int32)
sess = tf.InteractiveSession()
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (C, H, W, B),
extra_feed_dict={a: ha, b: hb})
dx_t, dx_n = grads[0]
dw_t, dw_n = grads[1]
print(dw_t)
print(dw_n)
#print(np.max(dw_t - dw_n))
#print(np.max(dx_t - dx_n))
np.savetxt('theoretical.dat', dw_t, fmt='%4.2f')
np.savetxt('numerical.dat', dw_n, fmt='%4.2f')
# Run
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
#print(result)
run_shift()

View File

@@ -19,6 +19,15 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t*, int32_t*,
triton::dnn::shift::type, bool> shift_key_t;
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel {
public:
@@ -78,15 +87,27 @@ public:
// shapes
int64_t C, H, W, B, F;
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
int64_t D = 1, T = 1;
bool has_bias = false;
// shift configuration
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", OP, false);
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration
triton::dnn::shift* shift;
if(m_config.find(key) == m_config.end())
shift = m_config.emplace(key, new triton::dnn::shift(
B, C, D, H, W, T, R_, S_, F,
shift_h, shift_w, "fp32", "fp32", OP, has_bias))
.first->second.get();
else
shift = m_config.at(key).get();
// shapes for c
std::vector<int64> c_shapes;
for(int32_t x: shift.c_shapes())
for(int32_t x: shift->c_shapes())
c_shapes.push_back(x);
TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr;
@@ -94,38 +115,58 @@ public:
// return early if possible
if (out_shapes.num_elements() == 0)
return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
shift.init(stream, (triton::driver::cu_module*)kernel->module());
shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return shift.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
// get JIT
triton::jit* jit;
bool autotune = false;
if(m_jit.find(key) == m_jit.end()) {
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
shift->src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
shift->init(stream, (triton::driver::cu_module*)kernel->module());
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return shift->get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
jit->add_module("shift", src.c_str(), best.params);
}
else {
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function("shift");
shift->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Run
triton::driver::kernel* kernel = jit->get_function("shift");
triton::jit::launch_information info = jit->get_launch_info("shift");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// enqueue
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
}
private:
Tensor h_shift_h_;
Tensor h_shift_w_;
// triton::driver::buffer* d_shift_h_;
// triton::driver::buffer* d_shift_w_;
int R_;
int S_;
};
@@ -136,5 +177,21 @@ REGISTER_OP("ShiftConv")
.Input("b: float32")
.Attr("shift_h: tensor")
.Attr("shift_w: tensor")
.Output("c: float32")
;
.Output("c: float32");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
REGISTER_OP("ShiftConvDx")
.Input("a: float32")
.Input("b: float32")
.Attr("shift_h: tensor")
.Attr("shift_w: tensor")
.Output("c: float32");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
REGISTER_OP("ShiftConvDw")
.Input("a: float32")
.Input("b: float32")
.Attr("shift_h: tensor")
.Attr("shift_w: tensor")
.Output("c: float32");