Kernels: SizeType is now always "int". For now, I don't expect any data structure to have more than 2**31 entries. This improves performance in a number of routines.

This commit is contained in:
Philippe Tillet
2015-08-11 11:50:49 -07:00
parent f06c85c97c
commit b5cc1f7ddc
3 changed files with 35 additions and 54 deletions

View File

@@ -96,19 +96,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
stream << "" << _size_t << " gid0 = " << GlobalIdx0(backend) << ";" << std::endl;
stream << "" << _size_t << " gpid0 = " << GroupIdx0(backend) << ";" << std::endl;
stream << "" << _size_t << " gsize0 = " << GlobalSize0(backend) << ";" << std::endl;
stream << "for(" << _size_t << " r = " << GlobalIdx1(backend) << "; r < (M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << "; r += " << GlobalSize1(backend) << ")" << std::endl;
stream << "{" << std::endl;
stream << "" << _size_t << " lid1 = " << LocalIdx1(backend) <<";" << std::endl;
stream << "" << _size_t << " gid1 = " << GlobalIdx1(backend) <<";" << std::endl;
stream << "" << _size_t << " gpid1 = " << GroupIdx1(backend) << ";" << std::endl;
stream << "" << _size_t << " gsize1 = " << GlobalSize1(backend) <<";" << std::endl;
stream << "" << _size_t << " upper_bound_1 = ( M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << ";" << std::endl;
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab();
stream << "" << _size_t << " lidx = " << LocalIdx0(backend) << ";" << std::endl;
stream << "" << _size_t << " lidy = " << LocalIdx1(backend) <<";" << std::endl;
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
@@ -117,7 +110,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "{" << std::endl;
stream.inc_tab();
element_wise_loop_1D(stream, p_.fetch_policy, p_.simd_width, "c", "N", "gid0", "gsize0", device, [&](unsigned int simd_width)
element_wise_loop_1D(stream, p_.fetch_policy, p_.simd_width, "c", "N", GlobalIdx0(backend).get(), GlobalSize0(backend).get(), device, [&](unsigned int simd_width)
{
std::string data_type = append_width("#scalartype",simd_width);
@@ -161,7 +154,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "}" << std::endl;
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << expr->process("#name_buf[lidy*" + local_size_0_ld_str + "+ lidx] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(" << _size_t << " stride = " << p_.local_size_0/2 << "; stride >0; stride /=2)" << std::endl;
@@ -169,17 +162,17 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
stream << "if (lid0 < stride)" << std::endl;
stream << "if (lidx < stride)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
compute_index_dot(stream, e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]")
, e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx + stride]")
, e->root_op());
else
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -188,13 +181,13 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "}" << std::endl;
stream << "if (lid0 == 0 && r < M)";
stream << "if (lidx == 0 && r < M)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
if(p_.num_groups_0==1)
{
std::map<std::string, std::string> accessors;
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lidy*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
}
@@ -203,8 +196,8 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
for (mapped_dot const * e : dots)
{
if (e->is_index_dot())
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp_value[r + M*" + GroupIdx0(backend).get() + "] = #name_buf_value[lidy*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp[r + M*" + GroupIdx0(backend).get() + "] = #name_buf[lidy*" + local_size_0_ld_str + "];") << std::endl;
}
}
stream.dec_tab();
@@ -238,18 +231,10 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
stream << _size_t << " lsize0 = " << LocalSize0(backend) << ";" << std::endl;
stream << _size_t << " lid1 = " << LocalIdx1(backend) <<";" << std::endl;
stream << _size_t << " gid1 = " << GlobalIdx1(backend) <<";" << std::endl;
stream << _size_t << " gsize1 = " << GlobalSize1(backend) <<";" << std::endl;
stream << _size_t << " upper_bound_1 = ( M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << ";" << std::endl;
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream << "for(" << _size_t << " r = " << GlobalIdx1(backend) << "; r < (M +" << p_.local_size_1 - 1 << ")/" << p_.local_size_1 << "*" << p_.local_size_1 << "; r += " << GlobalSize1(backend) << "){" << std::endl;
stream.inc_tab();
stream << _size_t << " lidx = " << LocalIdx0(backend) << ";" << std::endl;
stream << _size_t << " lidy = " << LocalIdx1(backend) <<";" << std::endl;
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
@@ -258,7 +243,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "{" << std::endl;
stream.inc_tab();
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
stream << "for(" << _size_t << " c = lidx; c < " << p_.num_groups_0 << "; c += " << LocalSize0(backend) << "){" << std::endl;
stream.inc_tab();
for (mapped_dot* e: dots)
@@ -272,7 +257,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "}" << std::endl;
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << expr->process("#name_buf[lidy*" + local_size_0_ld_str + "+ lidx] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
stream << "for(" << _size_t << " stride = " << p_.local_size_0/2 << "; stride >0; stride /=2)" << std::endl;
@@ -280,17 +265,17 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream.inc_tab();
stream << LocalBarrier(backend) << ";" << std::endl;
stream << "if (lid0 < stride)" << std::endl;
stream << "if (lidx < stride)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
compute_index_dot(stream, e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]")
, e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf_value[lidy*" + local_size_0_ld_str + " + lidx + stride]")
, e->root_op());
else
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx]"), e->process("#name_buf[lidy*" + local_size_0_ld_str + " + lidx + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -299,12 +284,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << "}" << std::endl;
stream << "if (lid0 == 0 && r < M)";
stream << "if (lidx == 0 && r < M)";
stream << "{" << std::endl;
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lidy*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);