Kernels: fixed kernel fusion for DOT and GEMV

This commit is contained in:
Philippe Tillet
2015-12-05 19:14:09 -05:00
parent 7140f065c2
commit 004eebc038
4 changed files with 53 additions and 30 deletions

View File

@@ -52,18 +52,26 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
name[0] += suffix;
name[1] += suffix;
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
for (const auto & e : dots)
auto unroll_tmp = [&]()
{
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype);
if (e->is_index_dot())
{
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value,");
}
else
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
}
unsigned int offset = 0;
for (const auto & e : dots)
{
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (e->is_index_dot())
{
stream << e->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + "*M);");
offset += 4*p_.num_groups_0;
stream << e->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
}
else{
stream << e->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
}
}
};
int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width;
switch(backend)
@@ -74,10 +82,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
}
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"},
@@ -239,10 +249,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
if(backend==driver::OPENCL)
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"},