Feature: Merged kernel-fusion branch

* Fuses multiple AXPY kernel
* Possibility to add thread-wise for loops in AXPY-like kernels
This commit is contained in:
Philippe Tillet
2015-09-30 15:31:41 -04:00
parent 149441b9e2
commit feeb1e9862
64 changed files with 10047 additions and 1119 deletions

View File

@@ -22,33 +22,27 @@ gemv_parameters::gemv_parameters(unsigned int _simd_width,
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int gemv::is_invalid_impl(driver::Device const &, math_expression const &) const
{
if (p_.fetch_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int gemv::lmem_usage(const expressions_tuple &) const
unsigned int gemv::lmem_usage(const math_expression&) const
{
return (p_.local_size_0+1)*p_.local_size_1;
}
std::string gemv::generate_impl(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string gemv::generate_impl(std::string const & suffix, math_expression const & expression, driver::Device const & device, mapping_type const & mapping) const
{
using tools::to_string;
std::vector<mapped_gemv*> dots;
expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::const_iterator mit;
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
{
array_expression const & first_expression = *expressions.data().front();
std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
for (auto & elem : idx)
dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
}
std::vector<size_t> idx = filter_nodes(&is_dot, expression, expression.root(), false);
for (auto & elem : idx)
dots.push_back((mapped_gemv*)(mapping.at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
kernel_generation_stream stream;
driver::backend_type backend = device.backend();
@@ -61,7 +55,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
for (const auto & e : dots)
{
std::string numeric_type = to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype);
if (e->is_index_dot())
{
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
@@ -80,14 +74,14 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
}
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mappings, expressions) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
process(stream, PARENT_NODE_TYPE,
{{"array0", "#scalartype #namereg = #pointer[#start];"},
{"array1", "#pointer += #start;"},
{"array2", "#pointer += #start;"}}, expressions, mappings);
{"array2", "#pointer += #start;"}}, expression, mapping);
unsigned int local_size_0_ld = p_.local_size_0;
std::string local_size_0_ld_str = to_string(local_size_0_ld);
@@ -115,23 +109,23 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
element_wise_loop_1D(stream, p_.fetch_policy, (dot_type_==REDUCE_COLUMNS)?p_.simd_width:1, "c", "N", GlobalIdx0(backend).get(), GlobalSize0(backend).get(), device, [&](unsigned int row_simd_width)
{
std::set<std::string> already_fetched;
for (const auto & e : dots)
{
std::map<std::string, std::string> accessors;
if(dot_type_==REDUCE_COLUMNS)
{
std::string data_type = append_width("#scalartype",row_simd_width);
accessors["array2"] = data_type + " #namereg = " + vload(row_simd_width, "#scalartype", "c*#stride", "#pointer + r*#ld", backend,false)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(row_simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend,false)+";";
accessors["array2"] = data_type + " #namereg = " + vload(row_simd_width, "#scalartype", "c*#stride", "#pointer + r*#ld", "1", backend,false)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(row_simd_width, "#scalartype", "(c%#sub0)*#stride", "#pointer + (r%#sub1)*#stride ", "1", backend,false)+";";
}
else
{
std::string data_type = append_width("#scalartype",col_simd_width);
accessors["array2"] = data_type + " #namereg = " + vload(col_simd_width, "#scalartype", "0", "#pointer + r*#stride + c*#ld", backend,false) + ";";
accessors["repeat"] = "#scalartype #namereg = $VALUE{(r%#tuplearg0)*#stride, (c%#tuplearg1)*#stride};";
accessors["array2"] = data_type + " #namereg = " + vload(col_simd_width, "#scalartype", "0", "#pointer + r*#stride + c*#ld", "1", backend,false) + ";";
accessors["repeat"] = "#scalartype #namereg = $VALUE{(r%#sub0)*#stride, (c%#sub1)*#stride};";
}
e->process_recursive(stream, PARENT_NODE_TYPE, accessors);
e->process_recursive(stream, PARENT_NODE_TYPE, accessors, already_fetched);
}
//Update accumulators
@@ -196,7 +190,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
if(col_simd_width > 1)
accessors["gemv"] = access_vector_type(accessors["gemv"], s);
accessors["array1"] = "#pointer[(r +" + to_string(s) + ")*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
stream << evaluate(PARENT_NODE_TYPE, accessors, expression, expression.root(), mapping) << ";" << std::endl;
}
}
else
@@ -206,8 +200,8 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
if(col_simd_width > 1)
stream << "if(M - r > " << col_simd_width << "){" << std::endl;
if (e->is_index_dot())
stream << e->process(vstore(col_simd_width,"uint", "#name_buf_value[lidy*" + local_size_0_ld_str + "]", "0", "#name_temp_value + r + M*" + GroupIdx0(backend).get(),backend, false)) << ";" << std::endl;
stream << e->process(vstore(col_simd_width,"#scalartype", "#name_buf[lidy*" + local_size_0_ld_str + "]", "0", "#name_temp + r + M*" + GroupIdx0(backend).get(),backend, false)) << ";" << std::endl;
stream << e->process(vstore(col_simd_width,"uint", "#name_buf_value[lidy*" + local_size_0_ld_str + "]", "0", "#name_temp_value + r + M*" + GroupIdx0(backend).get(), "1", backend, false)) << ";" << std::endl;
stream << e->process(vstore(col_simd_width,"#scalartype", "#name_buf[lidy*" + local_size_0_ld_str + "]", "0", "#name_temp + r + M*" + GroupIdx0(backend).get(), "1", backend, false)) << ";" << std::endl;
if(col_simd_width > 1)
{
stream << "}" << std::endl;
@@ -233,7 +227,6 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
stream.dec_tab();
stream << "}" << std::endl;
// std::cout << stream.str() << std::endl;
if(p_.num_groups_0>1)
{
@@ -244,14 +237,14 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
if(backend==driver::OPENCL)
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mappings, expressions) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
process(stream, PARENT_NODE_TYPE,
{{"array0", "#scalartype #namereg = #pointer[#start];"},
{"array1", "#pointer += #start;"},
{"array2", "#pointer += #start; "}}, expressions, mappings);
{"array2", "#pointer += #start; "}}, expression, mapping);
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
@@ -316,7 +309,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
std::map<std::string, std::string> accessors;
accessors["gemv"] = "#name_buf[lidy*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
stream << evaluate(PARENT_NODE_TYPE, accessors, expression, expression.root(), mapping) << ";" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
@@ -338,41 +331,37 @@ gemv::gemv(gemv::parameters_type const & parameters,
base_impl<gemv, gemv_parameters>(parameters, binding_policy),
dot_type_(rtype){ }
std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> gemv::input_sizes(math_expression const & expression) const
{
array_expression const & first_expression = *expressions.data().front();
std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
std::vector<std::size_t> idx = filter_nodes(&is_dot, expression, expression.root(), false);
std::pair<int_t, int_t> MN = matrix_size(expression.tree(), lhs_most(expression.tree(), idx[0]));
if(dot_type_==REDUCE_COLUMNS)
std::swap(MN.first,MN.second);
return {MN.first, MN.second};
}
void gemv::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, controller<expressions_tuple> const & controller)
void gemv::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, base & fallback, execution_handler const & control)
{
expressions_tuple const & expressions = controller.x();
driver::Context const & context = expressions.context();
math_expression const & expression = control.x();
driver::Context const & context = expression.context();
std::vector<int_t> MN = input_sizes(expressions);
std::vector<array_expression::node const *> dots;
for (const auto & e : expressions.data())
{
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
for (auto & r : dots_idx)
dots.push_back(&(e)->tree()[r]);
}
std::vector<int_t> MN = input_sizes(expression);
std::vector<math_expression::node const *> dots;
std::vector<size_t> dots_idx = filter_nodes(&is_dot, expression, expression.root(), false);
for (size_t idx : dots_idx)
dots.push_back(&expression.tree()[idx]);
//Fallback
if(p_.simd_width>1 && requires_fallback(expressions))
if(p_.simd_width>1 && requires_fallback(expression))
{
fallback.enqueue(queue, program, "fallback", fallback, controller);
fallback.enqueue(queue, program, "fallback", fallback, control);
return;
}
//Kernel
std::vector< driver::Buffer > tmp;
std::vector< driver::Buffer > tmpidx;
unsigned int dtype_size = size_of(lhs_most(expressions.data().front()->tree(), expressions.data().front()->root()).lhs.dtype);
unsigned int dtype_size = size_of(lhs_most(expression.tree(), expression.root()).lhs.dtype);
std::string name[2] = {"prod", "reduce"};
name[0] += suffix;
@@ -410,14 +399,14 @@ void gemv::enqueue(driver::CommandQueue & queue, driver::Program const & program
kernel.setArg(n_arg++, tmp[i]);
i++;
}
set_arguments(expressions, kernel, n_arg, binding_policy_);
set_arguments(expression, kernel, n_arg, binding_policy_);
}
//NDRange
driver::NDRange global[2] = { driver::NDRange(p_.local_size_0*p_.num_groups_0, p_.local_size_1*p_.num_groups_1), driver::NDRange(p_.local_size_0, p_.local_size_1*p_.num_groups_1) };
driver::NDRange local[2] = { driver::NDRange(p_.local_size_0, p_.local_size_1), driver::NDRange(p_.local_size_0, p_.local_size_1) };
for(unsigned int i = 0 ; i < nk ; ++i)
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
}
gemv_n::gemv_n(gemv_parameters const & parameters,