This commit is contained in:
Philippe Tillet
2019-07-03 19:25:16 -07:00
parent 0d8faa5b1e
commit 1d88f0a36b
7 changed files with 194 additions and 61 deletions

View File

@@ -74,12 +74,12 @@ int main() {
   // shift
   std::vector<unsigned> params = {
-    16, 4, 64, 16, 4, 128, 2, 2, 1, 2, 4, 4, 16, 4
+    4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
   };
   std::ostringstream oss;
   shift.src(oss);
   std::string src = oss.str();
-  jit.autotune("shift", src.c_str(), benchmark);
+  // jit.autotune("shift", src.c_str(), benchmark);
   jit.add_module("shift", src.c_str(), params);
   triton::driver::kernel* kernel = jit.get_function("shift");
   triton::jit::launch_information info = jit.get_launch_info("shift");
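
Note: the change above comments out autotuning and bakes a hand-picked parameter vector into add_module(). A minimal sketch of how such a vector is typically obtained, assuming the triton::jit interface used in this file (the printing loop is illustrative and not part of the commit):

    // Run the autotuner once (slow), print the winning configuration, then paste
    // it into `params` so subsequent runs call add_module() directly and skip tuning.
    triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
    for (unsigned p : best.params)
      std::cout << p << ", ";   // copy this list into the hard-coded `params` above
    std::cout << std::endl;
    jit.add_module("shift", src.c_str(), best.params);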

View File

@@ -1,7 +1,9 @@
 import os
 import tensorflow as tf
+from tensorflow.python.framework import ops
 import numpy as np
 from time import time
 data_files_path = tf.resource_loader.get_data_files_path()
 library_dir = os.path.dirname(os.path.realpath(__file__))
 module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
@@ -42,23 +44,45 @@ def run_conv():
   result = sess.run([c], feed_dict = {a: ha,
                                       b: hb})[0]
+@ops.RegisterGradient('ShiftConv')
+def blocksparse_matmul_grad(op, dy):
+  shift_h = op.get_attr('shift_h')
+  shift_w = op.get_attr('shift_w')
+  x = op.inputs[0]
+  w = op.inputs[1]
+  dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
+  dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
+  return (dx, dw)
 def run_shift():
-  B, C, H, W = 16, 32, 32, 32
-  R, S, F = 3, 3, 32
+  B, C, H, W = 1, 16, 8, 8
+  R, S, F = 3, 3, 16
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
   b = tf.placeholder(tf.float32, shape=[C, F])
-  shift_h = tf.zeros(C, tf.int32)
-  shift_w = tf.zeros(C, tf.int32)
-  hshift_h = np.zeros(C, np.int32)
-  hshift_w = np.zeros(C, np.int32)
+  #hshift_h = np.random.randint(-R//2, R//2 + 1, size=C, dtype=np.int32)
+  #hshift_w = np.random.randint(-S//2, R//2 + 1, size=C, dtype=np.int32)
+  hshift_h = 0*np.ones(C, dtype=np.int32)
+  hshift_w = 0*np.ones(C, dtype=np.int32)
   c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
   # Reference
-  ha = np.random.rand(C, H, W, B)
-  hb = np.random.rand(C, F)
-  # Run
+  ha = np.ones((C, H, W, B), dtype=np.int32)
+  hb = np.ones((C, F), dtype=np.int32)
   sess = tf.InteractiveSession()
+  grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (C, H, W, B),
+                                   extra_feed_dict={a: ha, b: hb})
+  dx_t, dx_n = grads[0]
+  dw_t, dw_n = grads[1]
+  print(dw_t)
+  print(dw_n)
+  #print(np.max(dw_t - dw_n))
+  #print(np.max(dx_t - dx_n))
+  np.savetxt('theoretical.dat', dw_t, fmt='%4.2f')
+  np.savetxt('numerical.dat', dw_n, fmt='%4.2f')
+  # Run
   sess.run(tf.global_variables_initializer())
   result = sess.run([c], feed_dict = {a: ha,
                                       b: hb})[0]
+  #print(result)
 run_shift()

View File

@@ -19,6 +19,15 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
+typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t, int32_t,
+                   int32_t*, int32_t*,
+                   triton::dnn::shift::type, bool> shift_key_t;
+
+static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
+static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
+static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
+
 template<triton::dnn::shift::type OP>
 class ShiftConvOp : public OpKernel {
 public:
@@ -78,15 +87,27 @@ public:
     // shapes
     int64_t C, H, W, B, F;
     FillShapes(context, C, H, W, B, F, tf_a, tf_b);
+    int64_t D = 1, T = 1;
+    bool has_bias = false;
     // shift configuration
     int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
     int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
     std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
     std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
-    triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", OP, false);
+    shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
+    // create configuration
+    triton::dnn::shift* shift;
+    if(m_config.find(key) == m_config.end())
+      shift = m_config.emplace(key, new triton::dnn::shift(
+                                        B, C, D, H, W, T, R_, S_, F,
+                                        shift_h, shift_w, "fp32", "fp32", OP, has_bias))
+                      .first->second.get();
+    else
+      shift = m_config.at(key).get();
     // shapes for c
     std::vector<int64> c_shapes;
-    for(int32_t x: shift.c_shapes())
+    for(int32_t x: shift->c_shapes())
       c_shapes.push_back(x);
     TensorShape out_shapes(c_shapes);
     Tensor* tf_c = nullptr;
@@ -94,38 +115,58 @@ public:
     // return early if possible
     if (out_shapes.num_elements() == 0)
       return;
-    // initialize default compute device
-    triton::jit jit(ctx);
     // matrix multiplication parameters
     triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
     triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
-    // benchmark a given matrix multiplication kernel
-    auto benchmark = [&](triton::driver::kernel* kernel,
-                         triton::jit::launch_information info) {
-      // launch info
-      unsigned TM = info.global_range_size[0];
-      unsigned TN = info.global_range_size[1];
-      unsigned nthreads = info.num_threads;
-      shift.init(stream, (triton::driver::cu_module*)kernel->module());
-      shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
-      stream->synchronize();
-      double ts = triton::tools::bench([&](){ shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
-                        [&](){ stream->synchronize(); }, ctx->device());
-      return shift.get_nflops() / ts * 1e-3;
-    };
-    std::ostringstream oss;
-    shift.src(oss);
-    std::string src = oss.str();
-    triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
+    // get JIT
+    triton::jit* jit;
+    bool autotune = false;
+    if(m_jit.find(key) == m_jit.end()) {
+      jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
+      std::ostringstream oss;
+      shift->src(oss);
+      std::string src = oss.str();
+      auto benchmark = [&](triton::driver::kernel* kernel,
+                           triton::jit::launch_information info) {
+        // launch info
+        unsigned TM = info.global_range_size[0];
+        unsigned TN = info.global_range_size[1];
+        unsigned nthreads = info.num_threads;
+        shift->init(stream, (triton::driver::cu_module*)kernel->module());
+        shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
+        stream->synchronize();
+        double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
+                          [&](){ stream->synchronize(); }, ctx->device());
+        return shift->get_nflops() / ts * 1e-3;
+      };
+      // auto-tune and save result
+      if(autotune) {
+        triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
+        jit->add_module("shift", src.c_str(), best.params);
+      }
+      else {
+        jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
+      }
+      triton::driver::kernel* kernel = jit->get_function("shift");
+      shift->init(stream, (triton::driver::cu_module*)kernel->module());
+    }
+    else
+      jit = m_jit.at(key).get();
+    // Run
+    triton::driver::kernel* kernel = jit->get_function("shift");
+    triton::jit::launch_information info = jit->get_launch_info("shift");
+    // launch info
+    unsigned TM = info.global_range_size[0];
+    unsigned TN = info.global_range_size[1];
+    unsigned nthreads = info.num_threads;
+    // enqueue
+    shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
   }
 private:
   Tensor h_shift_h_;
   Tensor h_shift_w_;
-  // triton::driver::buffer* d_shift_h_;
-  // triton::driver::buffer* d_shift_w_;
   int R_;
   int S_;
 };
@@ -136,5 +177,21 @@ REGISTER_OP("ShiftConv")
.Input("b: float32") .Input("b: float32")
.Attr("shift_h: tensor") .Attr("shift_h: tensor")
.Attr("shift_w: tensor") .Attr("shift_w: tensor")
.Output("c: float32") .Output("c: float32");
;
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
REGISTER_OP("ShiftConvDx")
.Input("a: float32")
.Input("b: float32")
.Attr("shift_h: tensor")
.Attr("shift_w: tensor")
.Output("c: float32");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
REGISTER_OP("ShiftConvDw")
.Input("a: float32")
.Input("b: float32")
.Attr("shift_h: tensor")
.Attr("shift_w: tensor")
.Output("c: float32");
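
Note: the op above now memoizes both the compiled triton::jit and the triton::dnn::shift configuration in static maps keyed on the problem shape (shift_key_t), so JIT compilation only happens the first time a given configuration is seen. A minimal, self-contained sketch of that get-or-create pattern; KeyT and Jit are illustrative stand-ins for shift_key_t and triton::jit:

    #include <map>
    #include <memory>
    #include <tuple>

    using KeyT = std::tuple<int, int, int>;      // e.g. (B, C, F), for illustration only

    struct Jit { explicit Jit(KeyT) { /* expensive compilation happens here */ } };

    Jit* get_or_create(std::map<KeyT, std::unique_ptr<Jit>>& cache, const KeyT& key) {
      auto it = cache.find(key);
      if (it == cache.end())                     // first time: compile and store
        it = cache.emplace(key, std::make_unique<Jit>(key)).first;
      return it->second.get();                   // later calls: cache hit, no recompilation
    }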

View File

@@ -103,6 +103,7 @@ private:
 public:
   jit(driver::context* context);
   ~jit();
+  std::vector<unsigned> get_valid(const char *name, const char *src);
   tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
   void add_module(ir::module &module, const std::vector<unsigned>& params = {});
   void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});

View File

@@ -70,16 +70,26 @@ shift::shift(int B, int C,
 }
 
 void shift::build_deltas() {
-  // compute offset
-  auto offset = [&](unsigned c) {
-    return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
-  };
   h_deltas_.resize(MAX_C_);
-  // populate look-up table
-  for(unsigned c = 0; c < TK_; c++)
-    h_deltas_[c] = offset(c);
-  for(unsigned c = 0; c < C_; c++)
-    h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
+  if(ty_ == FPROP){
+    // compute offset
+    auto offset = [&](unsigned c) {
+      return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
+    };
+    // populate look-up table
+    for(unsigned c = 0; c < TK_; c++)
+      h_deltas_[c] = offset(c);
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
+  }
+  if(ty_ == BPROP){
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[c] = shift_h_[c]*ld_c_[1] + shift_w_[c]*ld_c_[2];
+  }
+  if(ty_ == WGRAD){
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[c] = shift_h_[c]*ld_b_[1] + shift_w_[c]*ld_b_[2];
+  }
 }
 
 size_t shift::a_size(){
@@ -102,7 +112,7 @@ std::vector<int32_t> shift::c_shapes(){
 }
 size_t shift::get_nflops() {
-  return 2. * M_ * N_ * K_;
+  return 2.*M_*N_*K_;
 }
@@ -114,15 +124,13 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
 void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
                     driver::buffer *a, driver::buffer *b, driver::buffer *c,
                     size_t TM, size_t TN, size_t nthreads) {
-  if(ty_ == WGRAD)
-    std::swap(a, b);
   kernel->setArg(0, a);
   kernel->setArg(1, b);
   kernel->setArg(2, c);
   kernel->setArg(3, M_);
   kernel->setArg(4, N_);
   kernel->setArg(5, K_);
-  kernel->setArg(6, B_*AH_*AW_);
+  kernel->setArg(6, M_);
   kernel->setArg(7, N_);
   kernel->setArg(8, B_);
   kernel->setArg(9, AH_);
@@ -177,7 +185,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
            restrict read_only align(16) )" << b_ty_ << R"( *b,
            fp32 *c,
            int32 M, int32 N, int32 K,
-           multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
+           int32 lda, int32 ldb,
            int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
   int32 rxa[TM] = get_global_range[TM](0);
   int32 ryb[TN] = get_global_range[TN](1);
@@ -203,11 +211,13 @@ if(ty_ == FPROP){
 }
 if(ty_ == WGRAD){
   os << R"(
-  int32 shift[TK, TN] = 0;)";
+  __constant__ int32* pd[TN] = delta + ryb;
+  int32 d[TN] = *pd;
+  int32 shift[TK, TN] = d[newaxis, :];)";
 }
 os << R"(
-  )" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
-  )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
+  )" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
+  )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
   )" << a_ty_ << " a[" << AS << R"(] = *pa;
   )" << b_ty_ << " b[" << BS << R"(] = *pb;
   for(int32 k = K; k > 0; k = k - TK){
@@ -239,7 +249,7 @@ if(ty_ == WGRAD){
   int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
   int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
   int32 inc[TK, TN] = mask ? 0 : shift;
-  pb = pb + TK;
+  pb = pb + TK)" << ldb0 << R"(;
   )" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
   @checkb b = *pbb;)";
 }
@@ -259,14 +269,15 @@ else{
 if(ty_ == BPROP){
   os << R"(
   int32 rcwhc[TM] = rxc / ABS;
-  int32 rcw[TM] = rcwhc % AW;
+  int32 rcw[TM] = (rcwhc % AW);
   int32 rchc[TM] = rcwhc / AW;
-  int32 rch[TM] = rchc % AH;
+  int32 rch[TM] = (rchc % AH);
   int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
   int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
   int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
-  fp32* shiftpc[TM, TN] = pc + 0;
-  pc = interior ? shiftpc : pc;
+  __constant__ int32* pd[TN] = delta + ryc;
+  fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
+  pc = interior ? shift_pc : pc;
   @checkc __atomic_add(pc, C);
   )";
 }
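
Note: build_deltas() above now fills a per-direction look-up table. For FPROP the delta of channel c is c*ld_a[0] + shift_h[c]*ld_a[1] + shift_w[c]*ld_a[2], i.e. the linear offset of that channel's shifted pixel, while BPROP and WGRAD only need the spatial part against the c/b strides. A small self-contained sketch of the FPROP case, assuming the [C, H, W, B] layout used by the test script; the toy strides are illustrative, the real ones come from the shift constructor:

    #include <cstdint>
    #include <vector>

    // delta[c] = c*ld[0] + shift_h[c]*ld[1] + shift_w[c]*ld[2], as in shift::build_deltas()
    std::vector<int32_t> fprop_deltas(const std::vector<int32_t>& ld,       // strides of a: {H*W*B, W*B, B, 1}
                                      const std::vector<int32_t>& shift_h,  // per-channel vertical shift
                                      const std::vector<int32_t>& shift_w)  // per-channel horizontal shift
    {
      std::vector<int32_t> deltas(shift_h.size());
      for (size_t c = 0; c < deltas.size(); ++c)
        deltas[c] = int32_t(c)*ld[0] + shift_h[c]*ld[1] + shift_w[c]*ld[2];
      return deltas;
    }
    // e.g. H = W = 8, B = 1 gives ld = {64, 8, 1, 1}; a channel shifted by (+1, -1)
    // reads from offset c*64 + 8 - 1 relative to the unshifted base pointer.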

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-//   std::cout << source << sd::endl;
+//   std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -96,6 +96,46 @@ jit::jit(driver::context *context): driver_context_(context),
 jit::~jit(){ }
 
+std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
+  // find metaparameters
+  auto ptt_module = make_triton_module(name, src);
+  ir::module &tt_module = *ptt_module;
+  // set parameters
+  passes_wrapper passes(target_.get());
+  passes.target_independent(tt_module);
+  passes.tune.run(tt_module);
+  auto mps = passes.tune.get_params(tt_module);
+  // create parameter ranges
+  std::vector<std::vector<unsigned>> ranges;
+  for(ir::metaparameter *mp: mps)
+    ranges.push_back(mp->get_space());
+  // iterate over parameters
+  std::vector<unsigned> result;
+  loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
+    if(!result.empty())
+      return;
+    std::map<ir::value*, std::vector<std::string>> errors;
+    unsigned i = 0;
+    for(ir::metaparameter *mp: mps)
+      mp->set_value(params[i++]);
+    passes.target_independent(tt_module);
+    passes.tune.init(tt_module);
+    passes.tune.check_constraints(errors);
+//    for(auto e: errors)
+//      for(auto x: e.second)
+//        std::cout << x << std::endl;
+//    std::cout << "-----" << std::endl;
+    if(!errors.empty())
+      return;
+    result = params;
+  });
+  if(result.empty())
+    throw std::runtime_error("couldn't find valid parameters");
+  return result;
+}
+
 jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
   // find metaparameters
   auto ptt_module = make_triton_module(name, src);
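
Note: get_valid() above walks the cartesian product of every metaparameter's candidate range via loop_nest and returns the first assignment that passes the tuner's constraint checks, instead of benchmarking them all as autotune() does; callers then use it as jit->add_module("shift", src, jit->get_valid("shift", src)). A self-contained sketch of that first-valid search pattern, where the recursion stands in for loop_nest and is_valid for tune.check_constraints:

    #include <functional>
    #include <stdexcept>
    #include <vector>

    std::vector<unsigned> first_valid(
        const std::vector<std::vector<unsigned>>& ranges,
        const std::function<bool(const std::vector<unsigned>&)>& is_valid) {
      std::vector<unsigned> current(ranges.size());
      std::function<bool(size_t)> recurse = [&](size_t i) {
        if (i == ranges.size())
          return is_valid(current);          // full tuple assembled: test it
        for (unsigned v : ranges[i]) {       // try every candidate for slot i
          current[i] = v;
          if (recurse(i + 1))
            return true;                     // stop at the first valid tuple
        }
        return false;
      };
      if (!recurse(0))
        throw std::runtime_error("couldn't find valid parameters");
      return current;
    }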