Kernels: fixed kernels fusion for DOT, GEMV

This commit is contained in:
Philippe Tillet
2015-12-05 19:14:09 -05:00
parent 7140f065c2
commit 004eebc038
4 changed files with 53 additions and 30 deletions

View File

@@ -70,23 +70,31 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
driver::backend_type backend = device.backend(); driver::backend_type backend = device.backend();
std::string _size_t = size_type(device); std::string _size_t = size_type(device);
std::string arguments = _size_t + " N, ";
for (unsigned int k = 0; k < N; ++k)
{
std::string numeric_type = to_string(lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype);
if (exprs[k]->is_index_dot())
{
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value, ");
}
else
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
}
std::string name[2] = {"prod", "reduce"}; std::string name[2] = {"prod", "reduce"};
name[0] += suffix; name[0] += suffix;
name[1] += suffix; name[1] += suffix;
auto unroll_tmp = [&]()
{
unsigned int offset = 0;
for (unsigned int k = 0; k < N; ++k)
{
numeric_type dtype = lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (exprs[k]->is_index_dot())
{
stream << exprs[k]->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + ");");
offset += 4*p_.num_groups;
stream << exprs[k]->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.num_groups;
}
else{
stream << exprs[k]->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.num_groups;
}
}
};
/* ------------------------ /* ------------------------
* First Kernel * First Kernel
* -----------------------*/ * -----------------------*/
@@ -98,10 +106,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << ",1,1)))" << std::endl; break; stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << ",1,1)))" << std::endl; break;
} }
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl; stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " N, char* tmp," << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
unroll_tmp();
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl; stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int gid = " <<GlobalIdx0(backend) << ";" << std::endl; stream << "unsigned int gid = " <<GlobalIdx0(backend) << ";" << std::endl;
stream << "unsigned int gpid = " <<GroupIdx0(backend) << ";" << std::endl; stream << "unsigned int gpid = " <<GroupIdx0(backend) << ";" << std::endl;
@@ -206,10 +216,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl; stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
unroll_tmp();
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl; stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl; stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;

View File

@@ -52,18 +52,26 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
name[0] += suffix; name[0] += suffix;
name[1] += suffix; name[1] += suffix;
std::string arguments = _size_t + " M, " + _size_t + " N, " ; auto unroll_tmp = [&]()
{
unsigned int offset = 0;
for (const auto & e : dots) for (const auto & e : dots)
{ {
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype); numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (e->is_index_dot()) if (e->is_index_dot())
{ {
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, "); stream << e->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + "*M);");
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value,"); offset += 4*p_.num_groups_0;
stream << e->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
} }
else else{
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp, "); stream << e->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
} }
}
};
int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width; int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width;
switch(backend) switch(backend)
@@ -74,10 +82,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break; stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
} }
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl; stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE, process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"}, {{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"}, {"arrayn", "#pointer += #start;"},
@@ -239,10 +249,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
if(backend==driver::OPENCL) if(backend==driver::OPENCL)
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl; stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl; stream << "{" << std::endl;
stream.inc_tab(); stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE, process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"}, {{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"}, {"arrayn", "#pointer += #start;"},

View File

@@ -200,7 +200,6 @@ extern "C"
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\ const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
int ldc)\ int ldc)\
{\ {\
std::cout << transa << " " << transb << " " << m << " " << n << " " << k << std::endl;\
if(k==1 && m>1 && n>1){\ if(k==1 && m>1 && n>1){\
sc::array dA((sc::int_t)m, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)A, false), 0, transa=='N'?1:lda);\ sc::array dA((sc::int_t)m, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)A, false), 0, transa=='N'?1:lda);\
sc::array dB((sc::int_t)n, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)B, false), 0, transb=='T'?1:ldb);\ sc::array dB((sc::int_t)n, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)B, false), 0, transb=='T'?1:ldb);\

View File

@@ -63,7 +63,7 @@ void test_impl(T epsilon, simple_vector_base<T> & cx, simple_vector_base<T> & c
RUN_TEST("s = x'.y", cs+=cx[i]*cy[i], 0, cs, ds = dot(x,y)); RUN_TEST("s = x'.y", cs+=cx[i]*cy[i], 0, cs, ds = dot(x,y));
RUN_TEST("s = exp(x'.y)", cs += cx[i]*cy[i], 0, std::exp(cs), ds = exp(dot(x,y))); RUN_TEST("s = exp(x'.y)", cs += cx[i]*cy[i], 0, std::exp(cs), ds = exp(dot(x,y)));
RUN_TEST("s = 1 + x'.y", cs += cx[i]*cy[i], 0, 1 + cs, ds = 1 + dot(x,y)); RUN_TEST("s = 1 + x'.y", cs += cx[i]*cy[i], 0, 1 + cs, ds = 1 + dot(x,y));
// RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y)); RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y));
RUN_TEST("s = max(x)", cs = std::max(cs, cx[i]), std::numeric_limits<T>::min(), cs, ds = max(x)); RUN_TEST("s = max(x)", cs = std::max(cs, cx[i]), std::numeric_limits<T>::min(), cs, ds = max(x));
RUN_TEST("s = min(x)", cs = std::min(cs, cx[i]), std::numeric_limits<T>::max(), cs, ds = min(x)); RUN_TEST("s = min(x)", cs = std::min(cs, cx[i]), std::numeric_limits<T>::max(), cs, ds = min(x));
} }