From 1d88f0a36b3555c68bfbc0f14953876d0cc747c5 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Wed, 3 Jul 2019 19:25:16 -0700
Subject: [PATCH] Add shift-conv gradient ops, per-shape kernel caching, and jit::get_valid

---
 examples/cpp/shift.cpp               |   4 +-
 examples/python/tensorflow/run.py    |  42 +++++++---
 examples/python/tensorflow/shift.cpp | 111 ++++++++++++++++++++-------
 include/triton/runtime/jit.h         |   1 +
 lib/dnn/shift.cpp                    |  55 +++++++------
 lib/driver/module.cpp                |   2 +-
 lib/runtime/jit.cpp                  |  40 ++++++++++
 7 files changed, 194 insertions(+), 61 deletions(-)

diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index b330a3a9c..3d7646d9e 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -74,12 +74,12 @@ int main() {
   // shift
   std::vector<unsigned> params = {
-    16, 4, 64, 16, 4, 128, 2, 2, 1, 2, 4, 4, 16, 4
+    4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
   };
   std::ostringstream oss;
   shift.src(oss);
   std::string src = oss.str();
-  jit.autotune("shift", src.c_str(), benchmark);
+//  jit.autotune("shift", src.c_str(), benchmark);
   jit.add_module("shift", src.c_str(), params);
   triton::driver::kernel* kernel = jit.get_function("shift");
   triton::jit::launch_information info = jit.get_launch_info("shift");
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 63acfdf2a..cd6365f52 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -1,7 +1,9 @@
 import os
 import tensorflow as tf
+from tensorflow.python.framework import ops
 import numpy as np
 from time import time
+
 data_files_path = tf.resource_loader.get_data_files_path()
 library_dir = os.path.dirname(os.path.realpath(__file__))
 module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
@@ -42,23 +44,45 @@ def run_conv():
   result = sess.run([c], feed_dict = {a: ha, b: hb})[0]
 
+
+@ops.RegisterGradient('ShiftConv')
+def blocksparse_matmul_grad(op, dy):
+  shift_h = op.get_attr('shift_h')
+  shift_w = op.get_attr('shift_w')
+  x = op.inputs[0]
+  w = op.inputs[1]
+  dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
+  dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
+  return (dx, dw)
+
 def run_shift():
-  B, C, H, W = 16, 32, 32, 32
-  R, S, F = 3, 3, 32
+  B, C, H, W = 1, 16, 8, 8
+  R, S, F = 3, 3, 16
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
   b = tf.placeholder(tf.float32, shape=[C, F])
-  shift_h = tf.zeros(C, tf.int32)
-  shift_w = tf.zeros(C, tf.int32)
-  hshift_h = np.zeros(C, np.int32)
-  hshift_w = np.zeros(C, np.int32)
+  #hshift_h = np.random.randint(-R//2, R//2 + 1, size=C, dtype=np.int32)
+  #hshift_w = np.random.randint(-S//2, S//2 + 1, size=C, dtype=np.int32)
+  hshift_h = 0*np.ones(C, dtype=np.int32)
+  hshift_w = 0*np.ones(C, dtype=np.int32)
   c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
   # Reference
-  ha = np.random.rand(C, H, W, B)
-  hb = np.random.rand(C, F)
-  # Run
+  ha = np.ones((C, H, W, B), dtype=np.int32)
+  hb = np.ones((C, F), dtype=np.int32)
   sess = tf.InteractiveSession()
+  grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (C, H, W, B),
+                                   extra_feed_dict={a: ha, b: hb})
+  dx_t, dx_n = grads[0]
+  dw_t, dw_n = grads[1]
+  print(dw_t)
+  print(dw_n)
+  #print(np.max(dw_t - dw_n))
+  #print(np.max(dx_t - dx_n))
+  np.savetxt('theoretical.dat', dw_t, fmt='%4.2f')
+  np.savetxt('numerical.dat', dw_n, fmt='%4.2f')
+  # Run
  sess.run(tf.global_variables_initializer())
   result = sess.run([c], feed_dict = {a: ha, b: hb})[0]
+  #print(result)
 
 run_shift()
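Note: with the @ops.RegisterGradient('ShiftConv') hook above, TensorFlow's autodiff differentiates through the op on its own, routing dy through the new shift_conv_dx / shift_conv_dw kernels. A minimal usage sketch (assumes the same `module`, `tf`, and `np` as in run.py above; shapes are illustrative and match run_shift):

    # Sketch: tf.gradients picks up the gradient registered for 'ShiftConv'
    # and emits ShiftConvDx / ShiftConvDw nodes for dx and dw.
    C, F = 16, 16
    a = tf.placeholder(tf.float32, shape=[C, 8, 8, 1])    # C, H, W, B layout
    b = tf.placeholder(tf.float32, shape=[C, F])
    hshift = np.zeros(C, dtype=np.int32)
    c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift),
                          shift_w=tf.make_tensor_proto(hshift))
    da, db = tf.gradients(c, [a, b])   # -> ShiftConvDx, ShiftConvDw nodes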
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 812912704..a049f869d 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -19,6 +19,15 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
+typedef std::tuple<int64_t, int64_t, int64_t, int64_t, int64_t,
+                   int64_t, int64_t, int64_t, int64_t,
+                   int32_t*, int32_t*,
+                   triton::dnn::shift::type, bool> shift_key_t;
+
+static std::map<shift_key_t, std::unique_ptr<triton::driver::stream>> m_stream;
+static std::map<shift_key_t, std::unique_ptr<triton::jit>>            m_jit;
+static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>>     m_config;
+
 template<triton::dnn::shift::type OP>
 class ShiftConvOp : public OpKernel {
 public:
@@ -78,15 +87,27 @@
     // shapes
     int64_t C, H, W, B, F;
     FillShapes(context, C, H, W, B, F, tf_a, tf_b);
+    int64_t D = 1, T = 1;
+    bool has_bias = false;
     // shift configuration
     int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
     int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
     std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
     std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
-    triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", OP, false);
+    shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
+    // create configuration
+    triton::dnn::shift* shift;
+    if(m_config.find(key) == m_config.end())
+      shift = m_config.emplace(key, new triton::dnn::shift(
+                                 B, C, D, H, W, T, R_, S_, F,
+                                 shift_h, shift_w, "fp32", "fp32", OP, has_bias))
+                      .first->second.get();
+    else
+      shift = m_config.at(key).get();
+    // shapes for c
     std::vector<int64_t> c_shapes;
-    for(int32_t x: shift.c_shapes())
+    for(int32_t x: shift->c_shapes())
       c_shapes.push_back(x);
     TensorShape out_shapes(c_shapes);
     Tensor* tf_c = nullptr;
@@ -94,38 +115,58 @@
     // return early if possible
     if (out_shapes.num_elements() == 0)
       return;
-    // initialize default compute device
-    triton::jit jit(ctx);
     // matrix multiplication parameters
     triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
     triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
-    // benchmark a given matrix multiplication kernel
-    auto benchmark = [&](triton::driver::kernel* kernel,
-                         triton::jit::launch_information info) {
-      // launch info
-      unsigned TM = info.global_range_size[0];
-      unsigned TN = info.global_range_size[1];
-      unsigned nthreads = info.num_threads;
-      shift.init(stream, (triton::driver::cu_module*)kernel->module());
-      shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
-      stream->synchronize();
-      double ts = triton::tools::bench([&](){ shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
-                                       [&](){ stream->synchronize(); }, ctx->device());
-      return shift.get_nflops() / ts * 1e-3;
-    };
-
-    std::ostringstream oss;
-    shift.src(oss);
-    std::string src = oss.str();
-    triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
+    // get JIT
+    triton::jit* jit;
+    bool autotune = false;
+    if(m_jit.find(key) == m_jit.end()) {
+      jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
+      std::ostringstream oss;
+      shift->src(oss);
+      std::string src = oss.str();
+      auto benchmark = [&](triton::driver::kernel* kernel,
+                           triton::jit::launch_information info) {
+        // launch info
+        unsigned TM = info.global_range_size[0];
+        unsigned TN = info.global_range_size[1];
+        unsigned nthreads = info.num_threads;
+        shift->init(stream, (triton::driver::cu_module*)kernel->module());
+        shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
+        stream->synchronize();
+        double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
+                                         [&](){ stream->synchronize(); }, ctx->device());
+        return shift->get_nflops() / ts * 1e-3;
+      };
+      // auto-tune and save result
+      if(autotune) {
+        triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
+        jit->add_module("shift", src.c_str(), best.params);
+      }
+      else {
+        jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
+      }
+      triton::driver::kernel* kernel = jit->get_function("shift");
+      shift->init(stream, (triton::driver::cu_module*)kernel->module());
+    }
+    else
+      jit = m_jit.at(key).get();
+    // Run
+    triton::driver::kernel* kernel = jit->get_function("shift");
+    triton::jit::launch_information info = jit->get_launch_info("shift");
+    // launch info
+    unsigned TM = info.global_range_size[0];
+    unsigned TN = info.global_range_size[1];
+    unsigned nthreads = info.num_threads;
+    // enqueue
+    shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
   }
 
 private:
   Tensor  h_shift_h_;
   Tensor  h_shift_w_;
-//  triton::driver::buffer* d_shift_h_;
-//  triton::driver::buffer* d_shift_w_;
   int R_;
   int S_;
 };
@@ -136,5 +177,21 @@ REGISTER_OP("ShiftConv")
     .Input("b: float32")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
-    .Output("c: float32")
-;
+    .Output("c: float32");
+
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
+REGISTER_OP("ShiftConvDx")
+    .Input("a: float32")
+    .Input("b: float32")
+    .Attr("shift_h: tensor")
+    .Attr("shift_w: tensor")
+    .Output("c: float32");
+
+REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
+REGISTER_OP("ShiftConvDw")
+    .Input("a: float32")
+    .Input("b: float32")
+    .Attr("shift_h: tensor")
+    .Attr("shift_w: tensor")
+    .Output("c: float32");
+
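Note: the three maps above memoize one stream/JIT/configuration triple per problem signature, so Triton codegen and compilation run only on the first Compute() call for a given shape; subsequent calls just re-enqueue the cached kernel. The pattern, schematically (Python for brevity; names are illustrative, not the actual API):

    # One compiled kernel per problem signature, built lazily on first use.
    # The key mirrors shift_key_t: shapes, shift pointers, op type, has_bias.
    _kernels = {}

    def get_or_compile(key, compile_fn):
        if key not in _kernels:
            _kernels[key] = compile_fn()   # expensive: codegen + JIT compile
        return _kernels[key]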
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index 6bc377c95..ca5395893 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -103,6 +103,7 @@ private:
 public:
   jit(driver::context* context);
   ~jit();
+  std::vector<unsigned> get_valid(const char *name, const char *src);
   tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
   void add_module(ir::module &module, const std::vector<unsigned>& params = {});
   void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index e54ac4bdb..1a640e91e 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -70,16 +70,26 @@ shift::shift(int B, int C,
 }
 
 void shift::build_deltas() {
-  // compute offset
-  auto offset = [&](unsigned c) {
-    return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
-  };
   h_deltas_.resize(MAX_C_);
-  // populate look-up table
-  for(unsigned c = 0; c < TK_; c++)
-    h_deltas_[c] = offset(c);
-  for(unsigned c = 0; c < C_; c++)
-    h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
+  if(ty_ == FPROP){
+    // compute offset
+    auto offset = [&](unsigned c) {
+      return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
+    };
+    // populate look-up table
+    for(unsigned c = 0; c < TK_; c++)
+      h_deltas_[c] = offset(c);
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
+  }
+  if(ty_ == BPROP){
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[c] = shift_h_[c]*ld_c_[1] + shift_w_[c]*ld_c_[2];
+  }
+  if(ty_ == WGRAD){
+    for(unsigned c = 0; c < C_; c++)
+      h_deltas_[c] = shift_h_[c]*ld_b_[1] + shift_w_[c]*ld_b_[2];
+  }
 }
 
 size_t shift::a_size(){
@@ -102,7 +112,7 @@ std::vector<int32_t> shift::c_shapes(){
 }
 
 size_t shift::get_nflops() {
-  return 2. * M_ * N_ * K_;
+  return 2.*M_*N_*K_;
 }
 
@@ -114,15 +124,13 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
 void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
                     driver::buffer *a, driver::buffer *b, driver::buffer *c,
                     size_t TM, size_t TN, size_t nthreads) {
-  if(ty_ == WGRAD)
-    std::swap(a, b);
   kernel->setArg(0, a);
   kernel->setArg(1, b);
   kernel->setArg(2, c);
   kernel->setArg(3, M_);
   kernel->setArg(4, N_);
   kernel->setArg(5, K_);
-  kernel->setArg(6, B_*AH_*AW_);
+  kernel->setArg(6, M_);
   kernel->setArg(7, N_);
   kernel->setArg(8, B_);
   kernel->setArg(9, AH_);
@@ -177,7 +185,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
            restrict read_only align(16) )" << b_ty_ << R"( *b,
            fp32 *c,
            int32 M, int32 N, int32 K,
-           multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
+           int32 lda, int32 ldb,
            int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
   int32 rxa[TM] = get_global_range[TM](0);
   int32 ryb[TN] = get_global_range[TN](1);
@@ -203,11 +211,13 @@ if(ty_ == FPROP){
 }
 if(ty_ == WGRAD){
   os << R"(
-    int32 shift[TK, TN] = 0;)";
+    __constant__ int32* pd[TN] = delta + ryb;
+    int32 d[TN] = *pd;
+    int32 shift[TK, TN] = d[newaxis, :];)";
 }
 os << R"(
-    )" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
-    )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
+    )" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
+    )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
     )" << a_ty_ << " a[" << AS << R"(] = *pa;
     )" << b_ty_ << " b[" << BS << R"(] = *pb;
     for(int32 k = K; k > 0; k = k - TK){
@@ -239,7 +249,7 @@ if(ty_ == WGRAD){
       int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
       int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
       int32 inc[TK, TN] = mask ? 0 : shift;
-      pb = pb + TK;
+      pb = pb + TK)" << ldb0 << R"(;
       )" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
       @checkb b = *pbb;)";
 }
@@ -259,14 +269,15 @@ else{
 if(ty_ == BPROP){
   os << R"(
       int32 rcwhc[TM] = rxc / ABS;
-      int32 rcw[TM] = rcwhc % AW;
+      int32 rcw[TM] = (rcwhc % AW);
       int32 rchc[TM] = rcwhc / AW;
-      int32 rch[TM] = rchc % AH;
+      int32 rch[TM] = (rchc % AH);
       int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
       int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
      int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
-      fp32* shiftpc[TM, TN] = pc + 0;
-      pc = interior ? shiftpc : pc;
+      __constant__ int32* pd[TN] = delta + ryc;
+      fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
+      pc = interior ? shift_pc : pc;
       @checkc __atomic_add(pc, C);
 )";
 }
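Note on build_deltas above: for BPROP and WGRAD the look-up table stores, per channel, the flat displacement that the channel's (shift_h, shift_w) pair induces, using the strides of the tensor being shifted (ld_c_ for BPROP, ld_b_ for WGRAD); FPROP additionally stores running increments so the kernel can advance its pointers TK channels at a time. A NumPy sketch of the per-channel displacement (stride names are illustrative):

    import numpy as np

    # delta[c] = shift_h[c]*ld1 + shift_w[c]*ld2, as in the BPROP/WGRAD
    # branches of build_deltas; ld1/ld2 are the H and W strides.
    def shift_deltas(shift_h, shift_w, ld1, ld2):
        shift_h = np.asarray(shift_h, dtype=np.int32)
        shift_w = np.asarray(shift_w, dtype=np.int32)
        return shift_h * ld1 + shift_w * ld2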
diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp
index f11118401..4ff863666 100755
--- a/lib/driver/module.cpp
+++ b/lib/driver/module.cpp
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-//  std::cout << source << sd::endl;
+//  std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 85e51b22f..30547a19e 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -96,6 +96,46 @@ jit::jit(driver::context *context): driver_context_(context),
 jit::~jit(){
 }
 
+std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
+  // find metaparameters
+  auto ptt_module = make_triton_module(name, src);
+  ir::module &tt_module = *ptt_module;
+  // set parameters
+  passes_wrapper passes(target_.get());
+  passes.target_independent(tt_module);
+  passes.tune.run(tt_module);
+  auto mps = passes.tune.get_params(tt_module);
+  // create parameter ranges
+  std::vector<std::vector<unsigned>> ranges;
+  for(ir::metaparameter *mp: mps)
+    ranges.push_back(mp->get_space());
+  // iterate over parameters
+  std::vector<unsigned> result;
+  loop_nest(ranges, [&](const std::vector<unsigned> params){
+    if(!result.empty())
+      return;
+    std::map<ir::value*, std::vector<std::string>> errors;
+    unsigned i = 0;
+    for(ir::metaparameter *mp: mps)
+      mp->set_value(params[i++]);
+    passes.target_independent(tt_module);
+    passes.tune.init(tt_module);
+    passes.tune.check_constraints(errors);
+//    for(auto e: errors)
+//      for(auto x: e.second)
+//        std::cout << x << std::endl;
+//    std::cout << "-----" << std::endl;
+    if(!errors.empty())
+      return;
+    result = params;
+  });
+  if(result.empty())
+    throw std::runtime_error("couldn't find valid parameters");
+  return result;
+}
+
+
+
 jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
   // find metaparameters
   auto ptt_module = make_triton_module(name, src);
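Note: jit::get_valid above is the non-benchmarking counterpart of autotune: it walks the cartesian product of every metaparameter's space and returns the first assignment that passes the tuner's constraint check, throwing if none does. The same strategy in Python (hypothetical helper names; itertools.product plays the role of loop_nest):

    import itertools

    # First-valid search over metaparameter spaces, mirroring jit::get_valid:
    # stop at the first candidate whose constraint check reports no errors.
    def first_valid(ranges, check_constraints):
        for params in itertools.product(*ranges):
            if not check_constraints(params):   # empty error map -> valid
                return list(params)
        raise RuntimeError("couldn't find valid parameters")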