Kernels: fixed kernels fusion for DOT, GEMV
This commit is contained in:
@@ -70,23 +70,31 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
|
||||
driver::backend_type backend = device.backend();
|
||||
std::string _size_t = size_type(device);
|
||||
|
||||
std::string arguments = _size_t + " N, ";
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
std::string numeric_type = to_string(lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype);
|
||||
if (exprs[k]->is_index_dot())
|
||||
{
|
||||
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value, ");
|
||||
}
|
||||
else
|
||||
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
|
||||
}
|
||||
|
||||
std::string name[2] = {"prod", "reduce"};
|
||||
name[0] += suffix;
|
||||
name[1] += suffix;
|
||||
|
||||
auto unroll_tmp = [&]()
|
||||
{
|
||||
unsigned int offset = 0;
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
numeric_type dtype = lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype;
|
||||
std::string sdtype = to_string(dtype);
|
||||
if (exprs[k]->is_index_dot())
|
||||
{
|
||||
stream << exprs[k]->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += 4*p_.num_groups;
|
||||
stream << exprs[k]->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += size_of(dtype)*p_.num_groups;
|
||||
}
|
||||
else{
|
||||
stream << exprs[k]->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
|
||||
offset += size_of(dtype)*p_.num_groups;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/* ------------------------
|
||||
* First Kernel
|
||||
* -----------------------*/
|
||||
@@ -98,10 +106,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << ",1,1)))" << std::endl; break;
|
||||
}
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " N, char* tmp," << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
|
||||
stream << "unsigned int gid = " <<GlobalIdx0(backend) << ";" << std::endl;
|
||||
stream << "unsigned int gpid = " <<GroupIdx0(backend) << ";" << std::endl;
|
||||
@@ -206,10 +216,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
|
||||
|
||||
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
|
||||
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
|
||||
|
||||
|
@@ -52,18 +52,26 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
name[0] += suffix;
|
||||
name[1] += suffix;
|
||||
|
||||
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
|
||||
auto unroll_tmp = [&]()
|
||||
{
|
||||
unsigned int offset = 0;
|
||||
for (const auto & e : dots)
|
||||
{
|
||||
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype);
|
||||
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
|
||||
std::string sdtype = to_string(dtype);
|
||||
if (e->is_index_dot())
|
||||
{
|
||||
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value,");
|
||||
stream << e->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += 4*p_.num_groups_0;
|
||||
stream << e->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.num_groups_0;
|
||||
}
|
||||
else
|
||||
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
|
||||
else{
|
||||
stream << e->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
|
||||
offset += size_of(dtype)*p_.num_groups_0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width;
|
||||
switch(backend)
|
||||
@@ -74,10 +82,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
|
||||
}
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array1", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"arrayn", "#pointer += #start;"},
|
||||
@@ -239,10 +249,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
|
||||
if(backend==driver::OPENCL)
|
||||
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
|
||||
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
unroll_tmp();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array1", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"arrayn", "#pointer += #start;"},
|
||||
|
@@ -200,7 +200,6 @@ extern "C"
|
||||
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
|
||||
int ldc)\
|
||||
{\
|
||||
std::cout << transa << " " << transb << " " << m << " " << n << " " << k << std::endl;\
|
||||
if(k==1 && m>1 && n>1){\
|
||||
sc::array dA((sc::int_t)m, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)A, false), 0, transa=='N'?1:lda);\
|
||||
sc::array dB((sc::int_t)n, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)B, false), 0, transb=='T'?1:ldb);\
|
||||
|
@@ -63,7 +63,7 @@ void test_impl(T epsilon, simple_vector_base<T> & cx, simple_vector_base<T> & c
|
||||
RUN_TEST("s = x'.y", cs+=cx[i]*cy[i], 0, cs, ds = dot(x,y));
|
||||
RUN_TEST("s = exp(x'.y)", cs += cx[i]*cy[i], 0, std::exp(cs), ds = exp(dot(x,y)));
|
||||
RUN_TEST("s = 1 + x'.y", cs += cx[i]*cy[i], 0, 1 + cs, ds = 1 + dot(x,y));
|
||||
// RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y));
|
||||
RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y));
|
||||
RUN_TEST("s = max(x)", cs = std::max(cs, cx[i]), std::numeric_limits<T>::min(), cs, ds = max(x));
|
||||
RUN_TEST("s = min(x)", cs = std::min(cs, cx[i]), std::numeric_limits<T>::max(), cs, ds = min(x));
|
||||
}
|
||||
|
Reference in New Issue
Block a user