From aa8bcf6bde843ea2ce995678afa31a34b2fe15e0 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 15 Jul 2019 21:03:58 -0700 Subject: [PATCH] [dnn/shift] added split-k for shift-conv --- examples/cpp/shift.cpp | 10 +- examples/python/tensorflow/run.py | 8 +- include/triton/dnn/base.h | 2 + include/triton/dnn/shift.h | 4 + include/triton/tools/bench.hpp | 20 +- lib/dnn/base.cpp | 18 +- lib/dnn/shift.cpp | 298 ++++++++++++------------------ lib/lang/expression.cpp | 4 + lib/runtime/jit.cpp | 5 +- 9 files changed, 166 insertions(+), 203 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 33ded064e..482fad6b4 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -18,8 +18,8 @@ int main() { // initialization int32_t R = 3, S = 3; - int32_t B = 32, F = 128; - int32_t H = 28, W = 28; + int32_t B = 128, F = 128; + int32_t H = 16, W = 16; int32_t C = 128; // random shifts @@ -44,9 +44,9 @@ int main() { std::swap(b_size, c_size); std::swap(a_size, b_size); } - std::vector ha(B*C*H*W); - std::vector hb(C*F); - std::vector hc(B*F*H*W); + std::vector ha(a_size); + std::vector hb(b_size); + std::vector hc(c_size); std::vector rc(hc.size()); // device buffers triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 971ad2898..893fc5b10 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -58,9 +58,9 @@ def blocksparse_matmul_grad(op, dy): return (dx, dw) def run_shift(): - B, C, H, W = 1, 16, 4, 4 + B, C, H, W = 2, 16, 4, 4 R, S, F = 3, 3, 16 - stride_h, stride_w = 2, 2 + stride_h, stride_w = 1, 1 np.random.seed(2) a = tf.placeholder(tf.float16, shape=[B, C, H, W]) b = tf.placeholder(tf.float16, shape=[C, F]) @@ -82,8 +82,8 @@ def run_shift(): dx_t, dx_n = grads[0] #import sys #np.set_printoptions(threshold=sys.maxsize) - print(dx_t) - print(dx_n) + print(dw_t) + print(dw_n) print(np.max(np.abs(dw_t - dw_n))) print(np.max(np.abs(dx_t - dx_n))) # Run diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h index e8ba1c47e..3045ffb49 100644 --- a/include/triton/dnn/base.h +++ b/include/triton/dnn/base.h @@ -43,6 +43,8 @@ protected: private: // initialize virtual void init_impl(driver::stream *, driver::cu_module *){ } + // deinitialize + virtual void deinit_impl(){ } // enqueue virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index 8f33aee66..fbff404ca 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -52,6 +52,7 @@ public: private: // initialize and enqueue void init_impl(driver::stream *stream, driver::cu_module *module); + void deinit_impl(); void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, triton::runtime::launch_information info); @@ -163,6 +164,9 @@ private: bool BT_; // layout layout_t layout_; + // locks + size_t max_locks_; + driver::buffer *locks_; }; } diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index f37c04371..3c584bb02 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -32,15 +32,17 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de double total_time = 0; op(); sync(); - float norm = 1; - // normalize clock if possible to get roughly constant result - if(auto cu_device = dynamic_cast(device)) - norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); - tmr.start(); - op(); - sync(); - times.push_back(norm*tmr.get().count()); - total_time+=times.back(); +// while(total_time*1e-9 < 1e-3){ + float norm = 1; + // normalize clock if possible to get roughly constant result + if(auto cu_device = dynamic_cast(device)) + norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); + tmr.start(); + op(); + sync(); + times.push_back(norm*tmr.get().count()); + total_time+=times.back(); +// } return *std::min_element(times.begin(), times.end()); } diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index f5e2af0b2..c4f5ace3e 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -29,19 +29,20 @@ void base::enqueue(driver::stream *stream, std::vector args, b rt::jit* jit; /* the current template has not already been compiled */ if(m_jit.find(this) == m_jit.end()) { - jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get(); + base* clone = this->clone(); + jit = m_jit.emplace(clone, new rt::jit(ctx)).first->second.get(); std::ostringstream oss; - triton_c_src(oss); + clone->triton_c_src(oss); std::string src = oss.str(); auto benchmark = [&](triton::driver::kernel* kernel, rt::launch_information info) { // launch info - unsigned nthreads = info.num_threads; - init_impl(stream, (triton::driver::cu_module*)kernel->module()); - enqueue_impl(stream, kernel, args, info); + clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); + clone->enqueue_impl(stream, kernel, args, info); stream->synchronize(); - double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); }, + double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, [&](){ stream->synchronize(); }, ctx->device()); + clone->deinit_impl(); return num_flops() / ts * 1e-3; }; // auto-tune and save result @@ -53,7 +54,7 @@ void base::enqueue(driver::stream *stream, std::vector args, b jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); } triton::driver::kernel* kernel = jit->get_function(name_.c_str()); - init_impl(stream, (triton::driver::cu_module*)kernel->module()); + clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); } /* retrieved compiled template */ else @@ -63,7 +64,8 @@ void base::enqueue(driver::stream *stream, std::vector args, b driver::kernel* kernel = jit->get_function(name_.c_str()); rt::launch_information info = jit->get_launch_info(name_.c_str()); /* launch */ - enqueue_impl(stream, kernel, args, info); + auto it = m_jit.find(this); + it->first->enqueue_impl(stream, kernel, args, info); } } diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index cc6dccc4d..87aaf32cb 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -124,6 +124,9 @@ shift::shift(int B, int C, if(layout_ == NCHW) shapes_c_ = {B, C, AH_, AW_}; } + // locks + max_locks_ = (op_ == WGRAD) ? 8192 : 0; + locks_ = nullptr; } base* shift::clone() const { @@ -195,11 +198,30 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) { build_delta_a(); triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a"); stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data()); + // locks + if(locks_ == nullptr && max_locks_ > 0){ + std::vector hlocks(2*max_locks_, 0); + locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*4); + stream->write(locks_, false, 0, hlocks); + } +} + +void shift::deinit_impl() { + if(locks_ != nullptr){ + delete locks_; + locks_ = nullptr; + } } void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, runtime::launch_information info) { + unsigned TM = info.global_range_size[0], TN = info.global_range_size[1]; + unsigned grid_0 = (M_ + TM - 1)/TM; + unsigned grid_1 = (N_ + TN - 1)/TN; + unsigned num_locks = grid_0 * grid_1; + unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1; + std::array grid = {grid_0, grid_1, grid_2}; driver::buffer *a = args[0], *b = args[1], *c = args[2]; kernel->setArg(0, a); kernel->setArg(1, b); @@ -228,8 +250,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(24, BW_); kernel->setArg(25, CH_); kernel->setArg(26, CW_); - unsigned TM = info.global_range_size[0], TN = info.global_range_size[1]; - std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; + kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_); + kernel->setArg(28, (int32_t)grid[0]); + kernel->setArg(29, (int32_t)grid[1]); if(op_ == BPROP){ size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4; ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes); @@ -256,12 +279,49 @@ void shift::triton_c_src(std::ostream &os) const { std::string BS = BS0 + ", " + BS1; bool is_chwn = layout_ == CHWN; + auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){ + if(is_chwn) { + return R"( + int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB; + int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB; + int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW; + int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)"; + } + else { + return R"( + int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW; + int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW; + int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH; + int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)"; + } + }; + + auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) { + std::string result; + if(shift_edge_h_) + result += "int1 interiorh[" + sz0 + "] = 1;\n "; + else + result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n "; + if(shift_edge_w_) + result += "int1 interiorw[" + sz0 + "] = 1;"; + else + result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));"; + result += R"( + int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];"; + return result; + }; + std::string result = R"( const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128}; -const tunable int32 TK = {)" + std::to_string(TK_) + R"(}; +const tunable int32 TK = {)" + std::to_string(TK_) + "};"; +if(op_ == WGRAD) + result += "const tunable int32 GZ = {1, 4, 16};"; +else + result += "const tunable int32 GZ = {1};"; +result += R"( __constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(]; void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, @@ -275,32 +335,32 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, int32 NB, int32 AH, int32 AW, int32 BH, int32 BW, - int32 CH, int32 CW) { + int32 CH, int32 CW, + int32* locks, int32 grid0, int32 grid1) { int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); + int32 rz = get_global_range[1](2); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 acc[TM, TN] = 0; int32 pad_h = BH / 2; - int32 pad_w = BW / 2;)"; + int32 pad_w = BW / 2; + int32 split = select(locks == 0, 1, GZ); + int32 div = K / split; + int32 rem = K % split; + K = select(rz < rem, div - 1, div); + int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)"; +if(op_ == WGRAD){ + result += R"( + rka = rka + offk; + rkb = rkb + offk; + )"; +} /* A offsets */ if(op_ == FPROP){ - if(is_chwn){ - result += R"( - int32 rawh[TM] = rxa / NB; - int32 rab[TM] = rxa % NB; - int32 raw[TM] = rawh % CW; - int32 rah[TM] = rawh / CW;)"; - } - else{ - result += R"( - int32 rabh[TM] = rxa / CW; - int32 raw[TM] = rxa % CW; - int32 rah[TM] = rabh % CH; - int32 rab[TM] = rabh / CH;)"; - } - result += R"( + result += + compute_bhw("ra", "TM", "rxa") + R"( raw = raw * stride_w; rah = rah * stride_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; @@ -309,35 +369,12 @@ if(op_ == FPROP){ int32 d[TK] = *pd; int32 offa_interior[TM, TK] = d[newaxis, :]; int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c; - )"; - if(shift_edge_h_) - result += " int1 interiorh[TM] = 1;\n"; - else - result += " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));\n"; - if(shift_edge_w_) - result += " int1 interiorw[TM] = 1;"; - else - result += " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));"; - result += R"( - int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis]; + )" + compute_interior("ra", "TM", "TK") + R"( int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; } if(op_ == BPROP){ - if(is_chwn){ - result += R"( - int32 rawh[TM] = rxa / NB; - int32 rab[TM] = rxa % NB; - int32 raw[TM] = rawh % CW; - int32 rah[TM] = rawh / CW;)"; - } - else{ - result += R"( - int32 rabh[TM] = rxa / CW; - int32 raw[TM] = rxa % CW; - int32 rah[TM] = rabh % CH; - int32 rab[TM] = rabh / CH;)"; - } - result += R"( + result += + compute_bhw("ra", "TM", "rxa") + R"( int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; @@ -348,21 +385,8 @@ if(op_ == WGRAD && layout_ == CHWN){ int32 offa1[TK, TM] = rka[:, newaxis];)"; } if(op_ == WGRAD && layout_ == NCHW){ - if(is_chwn){ - result += R"( - int32 rawh[TK] = rka / NB; - int32 rab[TK] = rka % NB; - int32 raw[TK] = rawh % CW; - int32 rah[TK] = rawh / CW;)"; - } - else{ - result += R"( - int32 rabh[TK] = rka / CW; - int32 raw[TK] = rka % CW; - int32 rah[TK] = rabh % CH; - int32 rab[TK] = rabh / CH;)"; - } - result += R"( + result += + compute_bhw("ra", "TK", "rka") + R"( int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offa1[TK, TM] = offxa[:, newaxis];)"; @@ -380,38 +404,15 @@ if(op_ == BPROP){ int32 offb1[TK, TN] = rkb[:, newaxis];)"; } if(op_ == WGRAD){ - if(is_chwn){ - result += R"( - int32 rbwh[TK] = rkb / NB; - int32 rbb[TK] = rkb % NB; - int32 rbw[TK] = rbwh % CW; - int32 rbh[TK] = rbwh / CW;)"; - } - else{ - result += R"( - int32 rbbh[TK] = rkb / CW; - int32 rbw[TK] = rkb % CW; - int32 rbh[TK] = rbbh % CH; - int32 rbb[TK] = rbbh / CH;)"; - } - result += R"( + result += + compute_bhw("rb", "TK", "rkb") + R"( __constant__ int32* pd[TN] = delta_a + ryb; int32 d[TN] = *pd; int32 shift[TK, TN] = d[newaxis, :]; rbw = rbw * stride_w; rbh = rbh * stride_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - )"; - if(shift_edge_h_) - result += " int1 interiorh[TK] = 1;\n"; - else - result += " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n"; - if(shift_edge_w_) - result += " int1 interiorw[TK] = 1;"; - else - result += " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));"; - result += R"( - int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; + )" + compute_interior("rb", "TK", "TN") + R"( int32 incb[TK, TN] = interior ? shift : 0; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)"; @@ -421,8 +422,8 @@ if(op_ == WGRAD){ result += R"( )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1; - int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; - int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; + int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(; + int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; for(int32 k = K; k > 0; k = k - TK){ @@ -450,22 +451,8 @@ if(op_ == WGRAD && layout_ == CHWN){ } if(op_ == WGRAD && layout_ == NCHW){ result += R"( - rka = rka + TK;)"; - if(is_chwn){ - result += R"( - int32 rawh[TK] = rka / NB; - int32 rab[TK] = rka % NB; - int32 raw[TK] = rawh % CW; - int32 rah[TK] = rawh / CW;)"; - } - else{ - result += R"( - int32 rabh[TK] = rka / CW; - int32 raw[TK] = rka % CW; - int32 rah[TK] = rabh % CH; - int32 rab[TK] = rabh / CH;)"; - } - result += R"( + rka = rka + TK;)" + + compute_bhw("ra", "TK", "rka") + R"( offxa = rab*lda_b + raw*lda_w + rah*lda_h; pa = A + offa0 + offxa[:, newaxis];)"; } @@ -475,36 +462,12 @@ if(op_ == WGRAD && layout_ == NCHW){ /* Increment B pointers */ if(op_ == WGRAD){ result += R"( - rkb = rkb + TK;)"; - if(is_chwn){ - result += R"( - int32 rbwh[TK] = rkb / NB; - int32 rbb[TK] = rkb % NB; - int32 rbw[TK] = rbwh % CW; - int32 rbh[TK] = rbwh / CW;)"; - } - else{ - result += R"( - int32 rbbh[TK] = rkb / CW; - int32 rbw[TK] = rkb % CW; - int32 rbh[TK] = rbbh % CH; - int32 rbb[TK] = rbbh / CH;)"; - } - result += R"( + rkb = rkb + TK;)" + + compute_bhw("rb", "TK", "rkb") + R"( rbw = rbw * stride_w; rbh = rbh * stride_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - )"; - if(shift_edge_h_) - result += " interiorh = 1;\n"; - else - result += " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n"; - if(shift_edge_w_) - result += " interiorw = 1;"; - else - result += " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));"; - result += R"( - interior = interiorh[:, newaxis] && interiorw[:, newaxis]; + )" + compute_interior("rb", "TK", "TN") + R"( incb = interior ? shift : 0; pb = B + offb0 + offkb[:, newaxis] + incb;)"; } @@ -524,41 +487,15 @@ if(op_ == BPROP){ /* C offsets */ if(op_ == BPROP){ - if(is_chwn){ - result += R"( - int32 rcwh[TM] = rxc / NB; - int32 rcb[TM] = rxc % NB; - int32 rcw[TM] = rcwh % CW; - int32 rch[TM] = rcwh / CW;)"; - } - else{ - result += R"( - int32 rcbh[TM] = rxc / CW; - int32 rcw[TM] = rxc % CW; - int32 rch[TM] = rcbh % CH; - int32 rcb[TM] = rcbh / CH;)"; - } - result += R"( + result += + compute_bhw("rc", "TM", "rxc") + R"( rcw = rcw * stride_w; rch = rch * stride_h; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == FPROP){ -if(is_chwn){ - result += R"( - int32 rcwh[TM] = rxc / NB; - int32 rcb[TM] = rxc % NB; - int32 rcw[TM] = rcwh % CW; - int32 rch[TM] = rcwh / CW;)"; -} -else{ - result += R"( - int32 rcbh[TM] = rxc / CW; - int32 rcw[TM] = rxc % CW; - int32 rch[TM] = rcbh % CH; - int32 rcb[TM] = rcbh / CH;)"; -} - result += R"( + result += + compute_bhw("rc", "TM", "rxc") + R"( int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == WGRAD){ @@ -572,17 +509,8 @@ if(op_ == WGRAD){ int1 checkc1[TN] = ryc < N; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; if(op_ == BPROP){ - result += "\n"; - if(shift_edge_h_) - result += " int1 interiorh[TM] = 1;\n"; - else - result += " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n"; - if(shift_edge_w_) - result += " int1 interiorw[TM] = 1;"; - else - result += " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));"; - result += R"( - int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; + result += R"( + )" + compute_interior("rc", "TM", "TN") + R"( __constant__ int32* pd[TN] = delta_a + ryc; )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; pc = interior ? shift_pc : pc; @@ -591,12 +519,32 @@ if(op_ == BPROP){ } else{ result += R"( - @checkc *pc = c;)"; + int1 has_lock = (GZ > 1) && (locks != 0); + if(has_lock){ + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + int32 *plock = locks + ridx + ridy*grid0; + while(__atomic_cas(plock, 0, 1)); + int32 *pcount = plock + grid0*grid1; + int32 count = *pcount; + int32 countp1 = select(count == split - 1, 0, count + 1); + if(count == 0) { + @checkc *pc = c; + *pcount = countp1; + } + else { + @checkc *pc = c + *pc; + *pcount = countp1; + } + *plock = 0; + } + else{ + @checkc *pc = c; + })"; } result += R"( })"; -// std::cout << result << std::endl; os << result; } diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index 85e98a771..1e0536801 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -73,10 +73,14 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir: return builder.create_icmpSGE(lhs, rhs, name); if(op_ == GE && is_int && !is_signed) return builder.create_icmpUGE(lhs, rhs, name); + if(op_ == EQ && is_ptr) + return builder.create_icmpEQ(lhs, rhs, name); if(op_ == EQ && is_float) return builder.create_fcmpOEQ(lhs, rhs, name); if(op_ == EQ && is_int) return builder.create_icmpEQ(lhs, rhs, name); + if(op_ == NE && is_ptr) + return builder.create_icmpNE(lhs, rhs, name); if(op_ == NE && is_float) return builder.create_fcmpONE(lhs, rhs, name); if(op_ == NE && is_int) diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index b55680a21..15d33b029 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -43,7 +43,8 @@ void loop_nest(std::vector const & ranges, size_t current = 0; while(true){ //Execute function - pool.add_job([values, &f](){ f(values); }); +// pool.add_job([values, &f](){ f(values); }); + f(values); //Increment counters while(values[i]++ == ranges[i] - 1){ if(i == 0) @@ -169,7 +170,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben ranges.push_back(mp->get_space()); // iterate over parameters tune_res_t best; - size_t nthreads = 1; + size_t nthreads = 4; std::mutex mutex; loop_nest(ranges, [&](const std::vector params){ std::map> errors;