Kernels: fixed kernels fusion for DOT, GEMV
This commit is contained in:
@@ -52,18 +52,26 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
name[0] += suffix;
|
||||
name[1] += suffix;
|
||||
|
||||
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
|
||||
for (const auto & e : dots)
|
||||
auto unroll_tmp = [&]()
|
||||
{
|
||||
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype);
|
||||
if (e->is_index_dot())
|
||||
{
|
||||
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value,");
|
||||
}
|
||||
else
|
||||
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
|
||||
}
|
||||
unsigned int offset = 0;
|
||||
for (const auto & e : dots)
|
||||
{
|
||||
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
|
||||
std::string sdtype = to_string(dtype);
|
||||
if (e->is_index_dot())
|
||||
{
|
||||
stream << e->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += 4*p_.num_groups_0;
|
||||
stream << e->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.num_groups_0;
|
||||
}
|
||||
else{
|
||||
stream << e->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.num_groups_0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width;
|
||||
switch(backend)
|
||||
@@ -74,10 +82,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
|
||||
}
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array1", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"arrayn", "#pointer += #start;"},
|
||||
@@ -239,10 +249,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
if(backend==driver::OPENCL)
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array1", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"arrayn", "#pointer += #start;"},
|
||||
|
Reference in New Issue
Block a user