diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 720c872f2..09483116e 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -8,8 +8,8 @@ int main() { - bool AT = false; - bool BT = true; + bool AT = true; + bool BT = false; typedef float T; std::string ty = "fp16"; size_t dt_nbytes = sizeof(T); @@ -37,7 +37,7 @@ int main() { stream->write(dc, true, 0, hc); stream->synchronize(); triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4); - gemm.enqueue(stream, {da, db, dc}, true); + gemm.enqueue(stream, {da, db, dc}, false); // stream->read(dc, true, 0, hc); // gemm.cpu_ref(rc, ha, hb); // for(size_t i = 0; i < M*N; i++) diff --git a/include/triton/codegen/alignment_info.h b/include/triton/codegen/alignment_info.h index b90263dbe..d2d72e176 100644 --- a/include/triton/codegen/alignment_info.h +++ b/include/triton/codegen/alignment_info.h @@ -14,12 +14,17 @@ namespace ir { namespace codegen{ class alignment_info { + struct cst_info { + unsigned num_cst; + unsigned value; + }; + private: // helpers bool is_first_axis_unit(ir::value *v); // populate maps - bool populate_is_constant(ir::value *v); + cst_info populate_is_constant(ir::value *v); unsigned populate_max_contiguous(ir::value *v); unsigned populate_starting_multiple(ir::value *v); @@ -29,7 +34,7 @@ public: unsigned get_max_contiguous(ir::value* v) const; private: - std::map is_constant_; + std::map is_constant_; std::map max_contiguous_; std::map starting_multiple_; }; diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index a88cb2ddf..d3088d73b 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -70,6 +70,7 @@ public: void target_independent(ir::module &module) { optimize_dot.run(module); optimize_trans.run(module); +// ir::print(module, std::cout); } void target_dependent(ir::module &module) { diff --git a/lib/codegen/alignment_info.cpp b/lib/codegen/alignment_info.cpp index 5b7564479..b7e0b3641 100644 --- a/lib/codegen/alignment_info.cpp +++ b/lib/codegen/alignment_info.cpp @@ -9,6 +9,18 @@ namespace triton { namespace codegen{ +inline int gcd(int a, int b) { + if (a == 0) + return b; + if (b == 0) + return a; + if (a == b) + return a; + if (a > b) + return gcd(a-b, b); + return gcd(a, b-a); +} + template inline T add_to_cache(ir::value *i, T value, std::map &map) { return map[i] = value; @@ -22,50 +34,69 @@ bool alignment_info::is_first_axis_unit(ir::value *x){ return true; } -bool alignment_info::populate_is_constant(ir::value *v) { +alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) { if(is_constant_.find(v) != is_constant_.end()) return is_constant_.at(v); // helper for the cache - auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); }; + auto cache = [this,v](cst_info value){ + return add_to_cache(v, value, is_constant_); } + ; // populate if(auto *x = dynamic_cast(v)){ ir::value *op = x->get_operand(0); - populate_is_constant(op); - if(is_first_axis_unit(op)) - return cache(true); + auto op_cst = populate_is_constant(op); + if(is_first_axis_unit(op)){ + unsigned num_cst = x->get_type()->get_tile_shapes()[0]->get_value(); + return cache({num_cst, op_cst.value}); + } } if(auto *x = dynamic_cast(v)) - return cache(true); + return cache({true, (unsigned)x->get_value()}); if(auto *x = dynamic_cast(v)){ - bool lhs = populate_is_constant(x->get_operand(0)); - bool rhs = populate_is_constant(x->get_operand(1)); - return cache(lhs && rhs); + ir::value* lhs_op = x->get_operand(0); + ir::value* rhs_op = x->get_operand(1); + cst_info lhs = populate_is_constant(lhs_op); + cst_info rhs = populate_is_constant(rhs_op); + if(lhs.num_cst==0 && rhs.value && x->is_int_div()){ + unsigned max_contiguous = populate_max_contiguous(lhs_op); + unsigned starting_multiple = populate_starting_multiple(lhs_op); + return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0}); + } + return cache({std::min(lhs.num_cst, rhs.num_cst), 0}); + } + if(auto *x = dynamic_cast(v)){ + ir::value* lhs_op = x->get_operand(0); + ir::value* rhs_op = x->get_operand(1); + cst_info lhs = populate_is_constant(lhs_op); + cst_info rhs = populate_is_constant(rhs_op); + return cache({std::min(lhs.num_cst, rhs.num_cst), 0}); } if(auto *x = dynamic_cast(v)){ - bool value_true = populate_is_constant(x->get_value_true()); - bool value_false = populate_is_constant(x->get_value_false()); - return cache(value_true && value_false); + cst_info value_true = populate_is_constant(x->get_value_true()); + cst_info value_false = populate_is_constant(x->get_value_false()); + return cache({std::min(value_true.num_cst, value_false.num_cst), 0}); } if(v->get_type()->is_tile_ty()) - return cache(false); + return cache({0, 0}); if(auto *x = dynamic_cast(v)){ // put a conservative initial value in phi node to avoid infinite recursion - bool result = true; + unsigned result = 1; for(unsigned n = 0; n < x->get_num_incoming(); n++){ ir::value* inc = x->get_incoming_value(n); if(is_constant_.find(inc) != is_constant_.end()) - result = is_constant_.at(inc); + result = is_constant_.at(inc).num_cst; } - cache(result); + cache({result, 0}); // recurse for(unsigned n = 0; n < x->get_num_incoming(); n++){ ir::value* inc = x->get_incoming_value(n); - result = result && populate_is_constant(inc); + result = std::min(result, populate_is_constant(inc).num_cst); } - return cache(result); + return cache({result, 0}); } // scalars are always constant in the contiguous dimension - return cache(true); + // but value is not known at compile-time + return cache({1, 0}); } unsigned alignment_info::populate_max_contiguous(ir::value *v){ @@ -95,13 +126,21 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){ ir::value* rhs = x->get_operand(1); unsigned lhs_max_contiguous = populate_max_contiguous(lhs); unsigned rhs_max_contiguous = populate_max_contiguous(rhs); - bool lhs_has_cst = populate_is_constant(lhs); - bool rhs_has_cst = populate_is_constant(rhs); - if(x->is_int_add_sub()){ - if(lhs_has_cst) - return cache(rhs_max_contiguous); - if(rhs_has_cst) + cst_info lhs_cst_info = populate_is_constant(lhs); + cst_info rhs_cst_info = populate_is_constant(rhs); + if(x->is_int_rem() && rhs_cst_info.value > 0) + return cache(std::min(lhs_max_contiguous, rhs_cst_info.value)); + if(x->is_int_mult()){ + if(rhs_cst_info.value == 1) return cache(lhs_max_contiguous); + if(lhs_cst_info.value == 1) + return cache(rhs_max_contiguous); + } + if(x->is_int_add_sub()){ + if(lhs_cst_info.num_cst) + return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst)); + if(rhs_cst_info.num_cst) + return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst)); } } if(auto *x = dynamic_cast(v)){ @@ -114,11 +153,11 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){ ir::value* rhs = x->get_operand(1); unsigned lhs_max_contiguous = populate_max_contiguous(lhs); unsigned rhs_max_contiguous = populate_max_contiguous(rhs); - bool lhs_has_cst = populate_is_constant(lhs); - bool rhs_has_cst = populate_is_constant(rhs); - if(lhs_has_cst) + auto lhs_cst_info = populate_is_constant(lhs); + auto rhs_cst_info = populate_is_constant(rhs); + if(lhs_cst_info.num_cst) return cache(rhs_max_contiguous); - if(rhs_has_cst) + if(rhs_cst_info.num_cst) return cache(lhs_max_contiguous); } if(auto *x = dynamic_cast(v)){ @@ -140,22 +179,12 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){ return cache(1); } -inline int gcd(int a, int b) { - if (a == 0) - return b; - if (b == 0) - return a; - if (a == b) - return a; - if (a > b) - return gcd(a-b, b); - return gcd(a, b-a); -} - unsigned alignment_info::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); - auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); }; + auto cache = [this,v](unsigned value){ + return add_to_cache(v, value, starting_multiple_); + }; // has metadata if(auto *x = dynamic_cast(v)){ unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); @@ -185,15 +214,16 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){ return cache(gcd(lhs, rhs)); if(x->is_int_div()) return cache(std::max(lhs / rhs, 1)); - if(x->is_int_rem()) - return cache(std::max(lhs % rhs, 1)); + if(x->is_int_rem() && rhs > 1) + return cache(gcd(lhs, rhs)); if(x->is_shl()) return cache(lhs << rhs); if(x->is_shr()) return cache(std::max(lhs >> rhs, 1)); } - if(auto *x = dynamic_cast(v)) + if(auto *x = dynamic_cast(v)){ return cache(x->get_value()); + } if(auto *x = dynamic_cast(v)){ return cache(x->get_first()->get_value()); } @@ -270,7 +300,6 @@ void alignment_info::run(ir::module &mod) { for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i: block->get_inst_list()){ populate_max_contiguous(i); -// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl; } } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 2d104d8d6..47b3f05fa 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -233,10 +233,15 @@ void tune::run(ir::module &mod) { for(ir::instruction *i : block->get_inst_list()){ if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN) continue; - if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ - ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 8)); - *params_.at(i).at("nts.d0") = *tmp; + if(auto *ld = dynamic_cast(i)) + if(i->get_type()->is_tile_ty()){ + ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty(); + size_t addr_space = ptr_ty->get_pointer_address_space(); + if(addr_space < 4){ + ir::type *ty = mod.get_builder().get_int32_ty(); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 8, 8)); + *params_.at(i).at("nts.d0") = *tmp; + } } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index efca7bec3..73bb474b8 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -51,8 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector args, b jit->add_module(name_.c_str(), src.c_str(), best.params); } else { -// jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); - jit->add_module(name_.c_str(), src.c_str(), {32, 128, 16, 128, 2, 2, 2, 2, 4, 4, 32, 8, 4, 1}); + jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); } triton::driver::kernel* kernel = jit->get_function(name_.c_str()); clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp index 82fdb431b..42c7793c2 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/gemm.cpp @@ -113,8 +113,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, int32 bound, int32 *locks, int32 grid0, int32 grid1) { int32 ridx = get_range_id(0); int32 ridy = get_range_id(1); - int32 rxa[TM] = ridx*TM + (0 ... TM); - int32 ryb[TN] = ridy*TN + (0 ... TN); + int32 rxa[TM] = ridx * TM + (0 ... TM); + int32 ryb[TN] = ridy * TN + (0 ... TN); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 c[TM, TN] = 0; diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 47e283769..844c982e7 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -27,7 +27,7 @@ shift::shift(int B, int C, layout_(layout){ // std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl; // max number of channels - TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 16; + TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32; MAX_C_ = 8192 + TK_; // activation sizes CD_ = AD_ / stride_d_; @@ -223,7 +223,7 @@ void shift::deinit_impl() { void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, runtime::launch_information info) { - unsigned TM = info.global_range_size[0], TN = info.global_range_size[1]; + unsigned TM = info.globals.at("TM"), TN = info.globals.at("TN"); unsigned grid_0 = (M_ + TM - 1)/TM; unsigned grid_1 = (N_ + TN - 1)/TN; unsigned num_locks = grid_0 * grid_1; @@ -278,6 +278,8 @@ void shift::triton_c_src(std::ostream &os) const { std::string usea = AT_ ? "trans(a)" : "a"; std::string useb = BT_ ? "trans(b)" : "b"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; + std::string stride_h = std::to_string(stride_h_); + std::string stride_w = std::to_string(stride_w_); if(AT_){ std::swap(AS0, AS1); std::swap(bca0, bca1); @@ -290,6 +292,11 @@ void shift::triton_c_src(std::ostream &os) const { std::string BS = BS0 + ", " + BS1; bool is_chwn = layout_ == CHWN; + std::string lda_b = is_chwn ? "1" : "lda_b"; + std::string ldb_b = is_chwn ? "1" : "ldb_b"; + std::string ldc_b = is_chwn ? "1" : "ldc_b"; + + auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){ std::string B = std::to_string(B_); std::string CW = std::to_string(ICW_); @@ -317,7 +324,7 @@ const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {)" + std::to_string(TK_) + "};"; if(op_ == WGRAD) - result += "const tunable int32 GZ = {1, 4, 16};"; + result += "const tunable int32 GZ = {1};"; else result += "const tunable int32 GZ = {1};"; @@ -329,30 +336,27 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, )" + c_ty_ + R"( *C, int32 M, int32 N, int32 K, int32 stride_h, int32 stride_w, - multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c, - multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c, - multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c, + multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c, + multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c, + multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c, int32 NB, int32 AH, int32 AW, int32 BH, int32 BW, int32 CH, int32 CW, int32* locks, int32 grid0, int32 grid1, int32 grid2) { - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); - int32 rz = get_global_range[1](2); + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + int32 rz = get_range_id(2); + int32 rxa[TM] = ridx*TM + (0 ... TM); + int32 ryb[TN] = ridy*TN + (0 ... TN); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 acc[TM, TN] = 0; int32 pad_h = BH / 2; - int32 pad_w = BW / 2; - int32 div = K / grid2; - int32 rem = K % grid2; - K = select(rz < rem, div - 1, div); - int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)"; + int32 pad_w = BW / 2;)"; if(op_ == WGRAD){ result += R"( - rka = rka + offk; - rkb = rkb + offk; + )"; } @@ -360,31 +364,26 @@ if(op_ == WGRAD){ if(op_ == FPROP){ result += compute_bhw("ra", "TM", "rxa") + R"( - raw = raw * stride_w; - rah = rah * stride_h; - int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; + raw = raw * )" + stride_w + R"(; + rah = rah * )" + stride_h + R"(; + int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int32 offa0[TM, TK] = offxa[:, newaxis]; __constant__ int32* pd[TK] = delta_a + rka; - multiple_of(4) int32 d[TK] = *pd; + multiple_of(8) int32 d[TK] = *pd; int32 offa1[TM, TK] = d[newaxis, :];)"; } if(op_ == BPROP){ result += compute_bhw("ra", "TM", "rxa") + R"( - int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; + int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; } -if(op_ == WGRAD && layout_ == CHWN){ - result += R"( - int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; - int32 offa1[TK, TM] = rka[:, newaxis];)"; -} -if(op_ == WGRAD && layout_ == NCHW){ +if(op_ == WGRAD){ result += compute_bhw("ra", "TK", "rka") + R"( int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; - int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; + int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int32 offa1[TK, TM] = offxa[:, newaxis];)"; } @@ -403,11 +402,11 @@ if(op_ == WGRAD){ result += compute_bhw("rb", "TK", "rkb") + R"( __constant__ int32* pd[TN] = delta_a + ryb; - int32 d[TN] = *pd; - int32 shift[TK, TN] = d[newaxis, :]; - rbw = rbw * stride_w; - rbh = rbh * stride_h; - int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; + multiple_of(8) int32 d[TN] = *pd; + multiple_of(8) int32 shift[TK, TN] = d[newaxis, :]; + rbw = rbw * )" + stride_w + R"(; + rbh = rbh * )" + stride_h + R"(; + int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)"; } @@ -416,8 +415,8 @@ if(op_ == WGRAD){ result += R"( )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1; - int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(; - int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(; + int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; + int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; for(int32 k = K; k > 0; k = k - TK){ @@ -436,15 +435,11 @@ if(op_ == BPROP){ result += R"( pa = pa + TK * lda_c;)"; } -if(op_ == WGRAD && layout_ == CHWN){ - result += R"( - pa = pa + TK;)"; -} -if(op_ == WGRAD && layout_ == NCHW){ +if(op_ == WGRAD){ result += R"( rka = rka + TK;)" + compute_bhw("ra", "TK", "rka") + R"( - offxa = rab*lda_b + raw*lda_w + rah*lda_h; + offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; pa = A + offa0 + offxa[:, newaxis];)"; } result += R"( @@ -455,9 +450,9 @@ if(op_ == WGRAD){ result += R"( rkb = rkb + TK;)" + compute_bhw("rb", "TK", "rkb") + R"( - rbw = rbw * stride_w; - rbh = rbh * stride_h; - offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; + rbw = rbw * )" + stride_w + R"(; + rbh = rbh * )" + stride_h + R"(; + offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; pb = B + offb0 + offkb[:, newaxis] + shift;)"; } if(op_ == FPROP){ @@ -471,21 +466,21 @@ if(op_ == BPROP){ result += R"( @checkb b = *pb; } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1);)"; + int32 rxc[TM] = ridx*TM + (0 ... TM); + int32 ryc[TN] = ridy*TN + (0 ... TN);)"; /* C offsets */ if(op_ == BPROP){ result += compute_bhw("rc", "TM", "rxc") + R"( - rcw = rcw * stride_w; - rch = rch * stride_h; - int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; + rcw = rcw * )" + stride_w + R"(; + rch = rch * )" + stride_h + R"(; + int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == FPROP){ result += compute_bhw("rc", "TM", "rxc") + R"( - int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; + int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == WGRAD){ result += R"( @@ -506,27 +501,7 @@ if(op_ == BPROP){ } else{ result += R"( - int1 has_lock = (GZ > 1) && (locks != 0); - if(has_lock){ - int32 ridx = get_range_id(0); - int32 ridy = get_range_id(1); - int32 *plock = locks + ridx + ridy*grid0; - int32 *pcount = plock + grid0*grid1; - while(__atomic_cas(plock, 0, 1) == 1); - int32 count = *pcount; - int32 countp1 = select(count == grid2 - 1, 0, count + 1); - if(count == 0) { - @checkc *pc = c; - } - else { - @checkc *pc = c + *pc; - } - *pcount = countp1; - *plock = 0; - } - else{ - @checkc *pc = c; - })"; + @checkc *pc = c;)"; } result += R"( })"; diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index e3ac042d1..063dbffc9 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -130,7 +130,7 @@ bool binary_operator::is_int_mult() const { } bool binary_operator::is_int_add_sub() const { - return op_ == llop::Add || llop::Sub; + return op_ == llop::Add || op_ == llop::Sub; }