diff --git a/include/isaac/kernels/keywords.h b/include/isaac/kernels/keywords.h index 7dc8fcb7f..262153277 100644 --- a/include/isaac/kernels/keywords.h +++ b/include/isaac/kernels/keywords.h @@ -22,10 +22,13 @@ static inline std::string size_type(driver::Device const & device) switch(device.backend()) { #ifdef ISAAC_WITH_CUDA - case driver::CUDA: return "int"; + case driver::CUDA: + return "int"; #endif - case driver::OPENCL: return device.address_bits()==32?"int":"long"; - default: throw; + case driver::OPENCL: + return "int"; + default: + throw; } } diff --git a/lib/driver/kernel.cpp b/lib/driver/kernel.cpp index 882b6f8ef..869d16457 100644 --- a/lib/driver/kernel.cpp +++ b/lib/driver/kernel.cpp @@ -83,18 +83,11 @@ void Kernel::setSizeArg(unsigned int index, size_t N) } #endif case OPENCL: - if(address_bits_==32){ - cl_int NN = N; - ocl::check(clSetKernelArg(h_.cl(), index, 4, &NN)); - } - else if(address_bits_==64) - { - cl_long NN = N; - ocl::check(clSetKernelArg(h_.cl(), index, 8, &NN)); - } - else - throw; + { + cl_int NN = N; + setArg(index, 4, &NN); break; + } default: throw; } diff --git a/lib/kernels/templates/gemv.cpp b/lib/kernels/templates/gemv.cpp index 1693627c8..d0817add1 100644 --- a/lib/kernels/templates/gemv.cpp +++ b/lib/kernels/templates/gemv.cpp @@ -96,19 +96,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co for (const auto & e : dots) stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl; - stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl; - stream << "" << _size_t << " gid0 = " << GlobalIdx0(backend) << ";" << std::endl; - stream << "" << _size_t << " gpid0 = " << GroupIdx0(backend) << ";" << std::endl; - stream << "" << _size_t << " gsize0 = " << GlobalSize0(backend) << ";" << std::endl; + stream << "for(" << _size_t << " r = " << GlobalIdx1(backend) << "; r < (M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << "; r += " << GlobalSize1(backend) << ")" << std::endl; + stream << "{" << std::endl; - stream << "" << _size_t << " lid1 = " << LocalIdx1(backend) <<";" << std::endl; - stream << "" << _size_t << " gid1 = " << GlobalIdx1(backend) <<";" << std::endl; - stream << "" << _size_t << " gpid1 = " << GroupIdx1(backend) << ";" << std::endl; - stream << "" << _size_t << " gsize1 = " << GlobalSize1(backend) <<";" << std::endl; - - stream << "" << _size_t << " upper_bound_1 = ( M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << ";" << std::endl; - stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl; stream.inc_tab(); + stream << "" << _size_t << " lidx = " << LocalIdx0(backend) << ";" << std::endl; + stream << "" << _size_t << " lidy = " << LocalIdx1(backend) <<";" << std::endl; for (const auto & e : dots) stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl; @@ -117,7 +110,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "{" << std::endl; stream.inc_tab(); - element_wise_loop_1D(stream, p_.fetch_policy, p_.simd_width, "c", "N", "gid0", "gsize0", device, [&](unsigned int simd_width) + element_wise_loop_1D(stream, p_.fetch_policy, p_.simd_width, "c", "N", GlobalIdx0(backend).get(), GlobalSize0(backend).get(), device, [&](unsigned int simd_width) { std::string data_type = append_width("#scalartype",simd_width); @@ -161,7 +154,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "}" << std::endl; for (auto & expr : dots) - stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl; + stream << expr->process("#name_buf[lidy*" + local_size_0_ld_str + "+ lidx] = #name_acc;") << std::endl; stream << "#pragma unroll" << std::endl; stream << "for(" << _size_t << " stride = " << p_.local_size_0/2 << "; stride >0; stride /=2)" << std::endl; @@ -169,17 +162,17 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream.inc_tab(); stream << LocalBarrier(backend) << ";" << std::endl; - stream << "if (lid0 < stride)" << std::endl; + stream << "if (lidx < stride)" << std::endl; stream << "{" << std::endl; stream.inc_tab(); for (auto & e : dots) if (e->is_index_dot()) - compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]") - , e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]") + compute_index_dot(stream, e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]") + , e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx + stride]") , e->root_op()); else - compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op()); + compute_dot(stream,e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]"), e->root_op()); stream.dec_tab(); stream << "}" << std::endl; @@ -188,13 +181,13 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "}" << std::endl; - stream << "if (lid0 == 0 && r < M)"; + stream << "if (lidx == 0 && r < M)" << std::endl; stream << "{" << std::endl; stream.inc_tab(); if(p_.num_groups_0==1) { std::map accessors; - accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]"; + accessors["gemv"] = "#name_buf[lidy*" + local_size_0_ld_str + "]"; accessors["array1"] = "#pointer[r*#stride]"; evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings); } @@ -203,8 +196,8 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co for (mapped_dot const * e : dots) { if (e->is_index_dot()) - stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl; - stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl; + stream << e->process("#name_temp_value[r + M*" + GroupIdx0(backend).get() + "] = #name_buf_value[lidy*" + local_size_0_ld_str + "];") << std::endl; + stream << e->process("#name_temp[r + M*" + GroupIdx0(backend).get() + "] = #name_buf[lidy*" + local_size_0_ld_str + "];") << std::endl; } } stream.dec_tab(); @@ -238,18 +231,10 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co for (const auto & e : dots) stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl; - stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl; - stream << _size_t << " lsize0 = " << LocalSize0(backend) << ";" << std::endl; - - stream << _size_t << " lid1 = " << LocalIdx1(backend) <<";" << std::endl; - stream << _size_t << " gid1 = " << GlobalIdx1(backend) <<";" << std::endl; - stream << _size_t << " gsize1 = " << GlobalSize1(backend) <<";" << std::endl; - - - - stream << _size_t << " upper_bound_1 = ( M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << ";" << std::endl; - stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl; + stream << "for(" << _size_t << " r = " << GlobalIdx1(backend) << "; r < (M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << "; r += " << GlobalSize1(backend) << "){" << std::endl; stream.inc_tab(); + stream << _size_t << " lidx = " << LocalIdx0(backend) << ";" << std::endl; + stream << _size_t << " lidy = " << LocalIdx1(backend) <<";" << std::endl; for (const auto & e : dots) stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl; @@ -258,7 +243,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "{" << std::endl; stream.inc_tab(); - stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl; + stream << "for(" << _size_t << " c = lidx; c < " << p_.num_groups_0 << "; c += " << LocalSize0(backend) << "){" << std::endl; stream.inc_tab(); for (mapped_dot* e: dots) @@ -272,7 +257,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "}" << std::endl; for (auto & expr : dots) - stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl; + stream << expr->process("#name_buf[lidy*" + local_size_0_ld_str + "+ lidx] = #name_acc;") << std::endl; stream << "#pragma unroll" << std::endl; stream << "for(" << _size_t << " stride = " << p_.local_size_0/2 << "; stride >0; stride /=2)" << std::endl; @@ -280,17 +265,17 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream.inc_tab(); stream << LocalBarrier(backend) << ";" << std::endl; - stream << "if (lid0 < stride)" << std::endl; + stream << "if (lidx < stride)" << std::endl; stream << "{" << std::endl; stream.inc_tab(); for (auto & e : dots) if (e->is_index_dot()) - compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]") - , e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]") + compute_index_dot(stream, e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]") + , e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx + stride]") , e->root_op()); else - compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op()); + compute_dot(stream,e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]"), e->root_op()); stream.dec_tab(); stream << "}" << std::endl; @@ -299,12 +284,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co stream << "}" << std::endl; - stream << "if (lid0 == 0 && r < M)"; + stream << "if (lidx == 0 && r < M)"; stream << "{" << std::endl; stream.inc_tab(); std::map accessors; - accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]"; + accessors["gemv"] = "#name_buf[lidy*" + local_size_0_ld_str + "]"; accessors["array1"] = "#pointer[r*#stride]"; evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);