Kernels: fixed kernels fusion for DOT, GEMV

This commit is contained in:
Philippe Tillet
2015-12-05 19:14:09 -05:00
parent 7140f065c2
commit 004eebc038
4 changed files with 53 additions and 30 deletions

View File

@@ -70,23 +70,31 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
driver::backend_type backend = device.backend();
std::string _size_t = size_type(device);
std::string arguments = _size_t + " N, ";
for (unsigned int k = 0; k < N; ++k)
{
std::string numeric_type = to_string(lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype);
if (exprs[k]->is_index_dot())
{
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value, ");
}
else
arguments += exprs[k]->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
}
std::string name[2] = {"prod", "reduce"};
name[0] += suffix;
name[1] += suffix;
auto unroll_tmp = [&]()
{
unsigned int offset = 0;
for (unsigned int k = 0; k < N; ++k)
{
numeric_type dtype = lhs_most(exprs[k]->math_expression().tree(), exprs[k]->math_expression().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (exprs[k]->is_index_dot())
{
stream << exprs[k]->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + ");");
offset += 4*p_.num_groups;
stream << exprs[k]->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.num_groups;
}
else{
stream << exprs[k]->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + ");");
offset += size_of(dtype)*p_.num_groups;
}
}
};
/* ------------------------
* First Kernel
* -----------------------*/
@@ -98,10 +106,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << ",1,1)))" << std::endl; break;
}
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " N, char* tmp," << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int gid = " <<GlobalIdx0(backend) << ";" << std::endl;
stream << "unsigned int gpid = " <<GroupIdx0(backend) << ";" << std::endl;
@@ -206,10 +216,12 @@ std::string dot::generate_impl(std::string const & suffix, math_expression const
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expressions) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;

View File

@@ -52,18 +52,26 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
name[0] += suffix;
name[1] += suffix;
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
auto unroll_tmp = [&]()
{
unsigned int offset = 0;
for (const auto & e : dots)
{
std::string numeric_type = to_string(lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype);
numeric_type dtype = lhs_most(e->math_expression().tree(), e->math_expression().root()).lhs.dtype;
std::string sdtype = to_string(dtype);
if (e->is_index_dot())
{
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp_value,");
stream << e->process("uint* #name_temp = (uint*)(tmp + " + tools::to_string(offset) + "*M);");
offset += 4*p_.num_groups_0;
stream << e->process(sdtype + "* #name_temp_value = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
}
else
arguments += e->process(Global(backend).get() + " " + numeric_type + "* #name_temp, ");
else{
stream << e->process(sdtype + "* #name_temp = (" + sdtype + "*)(tmp + " + tools::to_string(offset) + "*M);");
offset += size_of(dtype)*p_.num_groups_0;
}
}
};
int col_simd_width = (dot_type_ == REDUCE_COLUMNS) ? 1 : p_.simd_width;
switch(backend)
@@ -74,10 +82,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl; break;
}
stream << KernelPrefix(backend) << " void " << name[0] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[0] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"},
@@ -239,10 +249,12 @@ std::string gemv::generate_impl(std::string const & suffix, math_expression cons
if(backend==driver::OPENCL)
stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << arguments << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << KernelPrefix(backend) << " void " << name[1] << "(" << _size_t << " M, " << _size_t << " N, char* tmp, " << generate_arguments("#scalartype", device, mapping, expression) << ")" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
unroll_tmp();
process(stream, PARENT_NODE_TYPE,
{{"array1", "#scalartype #namereg = #pointer[#start];"},
{"arrayn", "#pointer += #start;"},

View File

@@ -200,7 +200,6 @@ extern "C"
const TYPE_CU *B, int ldb, TYPE_CU beta, TYPE_CU *C,\
int ldc)\
{\
std::cout << transa << " " << transb << " " << m << " " << n << " " << k << std::endl;\
if(k==1 && m>1 && n>1){\
sc::array dA((sc::int_t)m, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)A, false), 0, transa=='N'?1:lda);\
sc::array dB((sc::int_t)n, TYPE_ISAAC, sc::driver::Buffer((CUdeviceptr)B, false), 0, transb=='T'?1:ldb);\

View File

@@ -63,7 +63,7 @@ void test_impl(T epsilon, simple_vector_base<T> & cx, simple_vector_base<T> & c
RUN_TEST("s = x'.y", cs+=cx[i]*cy[i], 0, cs, ds = dot(x,y));
RUN_TEST("s = exp(x'.y)", cs += cx[i]*cy[i], 0, std::exp(cs), ds = exp(dot(x,y)));
RUN_TEST("s = 1 + x'.y", cs += cx[i]*cy[i], 0, 1 + cs, ds = 1 + dot(x,y));
// RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y));
RUN_TEST("s = x'.y + y'.y", cs+= cx[i]*cy[i] + cy[i]*cy[i], 0, cs, ds = dot(x,y) + dot(y,y));
RUN_TEST("s = max(x)", cs = std::max(cs, cx[i]), std::numeric_limits<T>::min(), cs, ds = max(x));
RUN_TEST("s = min(x)", cs = std::min(cs, cx[i]), std::numeric_limits<T>::max(), cs, ds = min(x));
}