2015-07-25 21:00:18 -07:00
# include <cstring>
2015-01-12 13:20:53 -05:00
# include <iostream>
2015-08-04 20:56:05 -07:00
# include "isaac/kernels/stream.h"
# include "isaac/kernels/keywords.h"
# include "isaac/kernels/templates/gemv.h"
2015-08-06 12:05:12 -07:00
2015-08-06 16:14:33 -07:00
# include "tools/arguments.hpp"
2015-08-06 12:05:12 -07:00
# include "tools/loop.hpp"
2015-08-06 16:14:33 -07:00
# include "tools/reductions.hpp"
# include "tools/vector_types.hpp"
2015-08-06 19:34:26 -07:00
# include <string>
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
namespace templates
{
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
// Builds the tuning parameters of a gemv kernel.
// _simd_width, _local_size_0 and _local_size_1 (followed by a constant 1) are forwarded to
// base::parameters_type; the group counts and the fetching policy are stored in this object.
gemv_parameters::gemv_parameters(unsigned int _simd_width,
                                 unsigned int _local_size_0, unsigned int _local_size_1,
                                 unsigned int _num_groups_0, unsigned int _num_groups_1,
                                 fetching_policy_type _fetch_policy)
  : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1)
  , num_groups_0(_num_groups_0)
  , num_groups_1(_num_groups_1)
  , fetch_policy(_fetch_policy)
{
}
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
int gemv : : is_invalid_impl ( driver : : Device const & , expressions_tuple const & ) const
2015-01-12 13:20:53 -05:00
{
if ( p_ . fetch_policy = = FETCH_FROM_LOCAL )
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE ;
return TEMPLATE_VALID ;
}
2015-08-05 09:24:10 -07:00
unsigned int gemv : : lmem_usage ( const expressions_tuple & ) const
2015-01-12 13:20:53 -05:00
{
2015-06-28 17:53:16 -07:00
return ( p_ . local_size_0 + 1 ) * p_ . local_size_1 ;
2015-01-12 13:20:53 -05:00
}
2015-08-05 11:42:08 -07:00
// Generates the source text of the gemv kernel(s) and returns it as a single string.
//   suffix      - appended to the two kernel base names held in name[0] / name[1].
//   expressions - expression tuple; used to collect the dot nodes and forwarded to
//                 generate_arguments(), process() and evaluate().
//   device      - supplies the backend (CUDA / OPENCL) and the size type used in the emitted code.
//   mappings    - one mapping per expression; the mapped dot objects are looked up in it.
// The first kernel (name[0]) is always emitted. The second kernel (name[1]) is emitted only
// when p_.num_groups_0 > 1; it combines the partial results the first kernel stored in the
// #name_temp buffers.
std : : string gemv : : generate_impl ( std : : string const & suffix , expressions_tuple const & expressions , driver : : Device const & device , std : : vector < mapping_type > const & mappings ) const
2015-01-12 13:20:53 -05:00
{
2015-08-06 20:20:08 -07:00
using tools : : to_string ;
2015-02-10 23:01:16 -05:00
2015-06-30 17:55:57 -04:00
2015-07-11 09:36:01 -04:00
// Collect the mapped object of every dot node, walking mappings and expressions in lockstep.
// NOTE(review): `sit` is advanced but never dereferenced; every iteration inspects
// expressions.data().front(). Confirm this is intended when the tuple holds several expressions.
std : : vector < mapped_gemv * > dots ;
2015-04-29 15:50:57 -04:00
expressions_tuple : : data_type : : const_iterator sit ;
std : : vector < mapping_type > : : const_iterator mit ;
for ( mit = mappings . begin ( ) , sit = expressions . data ( ) . begin ( ) ; mit ! = mappings . end ( ) ; + + mit , + + sit )
{
array_expression const & first_expression = * expressions . data ( ) . front ( ) ;
2015-07-11 09:36:01 -04:00
std : : vector < size_t > idx = filter_nodes ( & is_dot , first_expression , false ) ;
2015-04-29 15:50:57 -04:00
for ( auto & elem : idx )
2015-07-11 09:36:01 -04:00
dots . push_back ( ( mapped_gemv * ) ( mit - > at ( mapping_key ( elem , PARENT_NODE_TYPE ) ) . get ( ) ) ) ;
2015-04-29 15:50:57 -04:00
}
2015-01-12 13:20:53 -05:00
kernel_generation_stream stream ;
2015-04-29 15:50:57 -04:00
driver : : backend_type backend = device . backend ( ) ;
std : : string _size_t = size_type ( device ) ;
2015-01-12 13:20:53 -05:00
2015-08-05 11:42:08 -07:00
// Kernel names: the two base names below, each followed by the caller-supplied suffix.
std : : string name [ 2 ] = { " prod " , " reduce " } ;
name [ 0 ] + = suffix ;
name [ 1 ] + = suffix ;
2015-01-27 02:41:27 -05:00
2015-04-29 15:50:57 -04:00
// Leading kernel arguments, shared by both kernels: M and N, then one temporary buffer per dot.
// An index dot gets two buffers: an unsigned int one (#name_temp) and a numeric one (#name_temp_value).
std : : string arguments = _size_t + " M, " + _size_t + " N, " ;
2015-07-11 09:36:01 -04:00
for ( const auto & e : dots )
2015-02-10 16:33:38 -05:00
{
2015-08-06 12:05:12 -07:00
std : : string numeric_type = to_string ( lhs_most ( e - > array_expression ( ) . tree ( ) , e - > array_expression ( ) . root ( ) ) . lhs . dtype ) ;
2015-07-11 09:36:01 -04:00
if ( e - > is_index_dot ( ) )
2015-02-10 16:33:38 -05:00
{
2015-04-29 15:50:57 -04:00
arguments + = e - > process ( Global ( backend ) . get ( ) + " unsigned int* #name_temp, " ) ;
2015-08-06 19:34:26 -07:00
arguments + = e - > process ( Global ( backend ) . get ( ) + " " + numeric_type + " * #name_temp_value, " ) ;
2015-02-10 16:33:38 -05:00
}
else
2015-08-06 19:34:26 -07:00
arguments + = e - > process ( Global ( backend ) . get ( ) + " " + numeric_type + " * #name_temp, " ) ;
2015-02-10 16:33:38 -05:00
}
2015-08-13 10:06:18 -07:00
// Vector width used along r: 1 when reducing columns, p_.simd_width otherwise.
int col_simd_width = ( dot_type_ = = REDUCE_COLUMNS ) ? 1 : p_ . simd_width ;
2015-05-13 02:20:44 -04:00
// Backend-specific prologue: an include line for CUDA, a required work-group size attribute for OpenCL.
switch ( backend )
{
2015-08-25 12:41:21 -04:00
case driver : : CUDA :
stream < < " #include \" helper_math.h \" " < < std : : endl ; break ;
case driver : : OPENCL :
stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ; break ;
2015-05-13 02:20:44 -04:00
}
2015-04-29 15:50:57 -04:00
/////////////////////////////////////////
////////////// Kernel 1
////////////////////////////////////////
stream < < KernelPrefix ( backend ) < < " void " < < name [ 0 ] < < " ( " < < arguments < < generate_arguments ( " #scalartype " , device , mappings , expressions ) < < " ) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
// Emit the per-operand prologue through process(), using the array0 / array1 / array2 templates below.
process ( stream , PARENT_NODE_TYPE ,
2015-02-10 16:33:38 -05:00
{ { " array0 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " array1 " , " #pointer += #start; " } ,
2015-08-10 22:45:48 -07:00
{ " array2 " , " #pointer += #start; " } } , expressions , mappings ) ;
2015-01-12 13:20:53 -05:00
2015-06-28 17:53:16 -07:00
// Leading dimension of #name_buf: the buffer is indexed as lidy*local_size_0_ld + lidx.
unsigned int local_size_0_ld = p_ . local_size_0 ;
2015-02-10 23:01:16 -05:00
std : : string local_size_0_ld_str = to_string ( local_size_0_ld ) ;
2015-07-11 09:36:01 -04:00
// One buffer per dot, declared with the Local(backend) qualifier and local_size_1*local_size_0_ld entries.
for ( const auto & e : dots )
2015-08-13 10:06:18 -07:00
stream < < e - > process ( Local ( backend ) . get ( ) + " " + append_width ( " #scalartype " , col_simd_width ) + " #name_buf[ " + to_string ( p_ . local_size_1 * local_size_0_ld ) + " ]; " ) < < std : : endl ;
2015-04-29 15:50:57 -04:00
2015-08-13 10:06:18 -07:00
// Emitted outer loop over r. Its bound is written from M rounded up to a multiple of local_size_1,
// which is why the emitted body re-checks r < M before touching memory.
stream < < " for( " < < _size_t < < " r = " < < GlobalIdx1 ( backend ) < < " * " < < col_simd_width < < " ; r < (M + " < < p_ . local_size_1 - 1 < < " )/ " < < p_ . local_size_1 < < " * " < < p_ . local_size_1 * col_simd_width < < " ; r += " < < GlobalSize1 ( backend ) < < " * " < < col_simd_width < < " ) " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " { " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < " " < < _size_t < < " lidx = " < < LocalIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < " " < < _size_t < < " lidy = " < < LocalIdx1 ( backend ) < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
2015-08-13 10:06:18 -07:00
// Accumulators start at the neutral element of each dot's root operation.
// NOTE(review): only #name_acc is declared here; #name_acc_value, used further down for index
// dots, is not declared anywhere in this function. Confirm where it comes from.
for ( const auto & e : dots ) {
std : : string data_type = append_width ( " #scalartype " , col_simd_width ) ;
stream < < e - > process ( data_type + " #name_acc = " + neutral_element ( ( e ) - > root_op ( ) , backend , " #scalartype " ) + " ; " ) < < std : : endl ;
}
2015-01-12 13:20:53 -05:00
stream < < " if (r < M) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-08-13 10:06:18 -07:00
// Inner loop over c (bound N) is produced by element_wise_loop_1D; the callback receives the
// vector width it must generate the body for. That width is p_.simd_width only when reducing columns.
element_wise_loop_1D ( stream , p_ . fetch_policy , ( dot_type_ = = REDUCE_COLUMNS ) ? p_ . simd_width : 1 , " c " , " N " , GlobalIdx0 ( backend ) . get ( ) , GlobalSize0 ( backend ) . get ( ) , device , [ & ] ( unsigned int row_simd_width )
2015-01-12 13:20:53 -05:00
{
2015-06-30 17:55:57 -04:00
2015-07-11 09:36:01 -04:00
// Load the operands of every dot into registers; the load templates depend on the reduction direction.
for ( const auto & e : dots )
2015-01-12 13:20:53 -05:00
{
2015-02-10 16:33:38 -05:00
std : : map < std : : string , std : : string > accessors ;
2015-07-11 09:36:01 -04:00
if ( dot_type_ = = REDUCE_COLUMNS )
2015-01-12 13:20:53 -05:00
{
2015-08-13 10:06:18 -07:00
std : : string data_type = append_width ( " #scalartype " , row_simd_width ) ;
accessors [ " array2 " ] = data_type + " #namereg = " + vload ( row_simd_width , " #scalartype " , " c*#stride " , " #pointer + r*#ld " , backend ) + " ; " ;
accessors [ " repeat " ] = data_type + " #namereg = " + vload ( row_simd_width , " #scalartype " , " (c%#tuplearg0)*#stride " , " #pointer + (r%#tuplearg1)*#stride " , backend ) + " ; " ;
2015-02-10 16:33:38 -05:00
}
else
{
2015-08-13 10:06:18 -07:00
std : : string data_type = append_width ( " #scalartype " , col_simd_width ) ;
accessors [ " array2 " ] = data_type + " #namereg = " + vload ( col_simd_width , " #scalartype " , " 0 " , " #pointer + r*#stride + c*#ld " , backend ) + " ; " ;
2015-02-10 16:33:38 -05:00
accessors [ " repeat " ] = " #scalartype #namereg = $VALUE{(r%#tuplearg0)*#stride, (c%#tuplearg1)*#stride}; " ;
2015-01-12 13:20:53 -05:00
}
2015-02-10 16:33:38 -05:00
e - > process_recursive ( stream , PARENT_NODE_TYPE , accessors ) ;
}
2015-01-12 13:20:53 -05:00
2015-02-10 16:33:38 -05:00
//Update accumulators
2015-08-13 10:06:18 -07:00
// str[a] names lane a of the loaded register: the register itself for width 1, a vector
// component (via access_vector_type) otherwise.
std : : vector < std : : string > str ( row_simd_width ) ;
if ( row_simd_width = = 1 )
2015-02-10 16:33:38 -05:00
str [ 0 ] = " #namereg " ;
else
2015-08-13 10:06:18 -07:00
for ( unsigned int a = 0 ; a < row_simd_width ; + + a )
2015-04-29 15:50:57 -04:00
str [ a ] = access_vector_type ( " #namereg " , a ) ;
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
for ( auto & elem : dots )
2015-08-13 10:06:18 -07:00
for ( unsigned int a = 0 ; a < row_simd_width ; + + a )
2015-01-12 13:20:53 -05:00
{
2015-02-10 16:33:38 -05:00
std : : string value = elem - > evaluate_recursive ( LHS_NODE_TYPE , { { " array2 " , str [ a ] } , { " repeat " , str [ a ] } , { " array0 " , " #namereg " } } ) ;
2015-07-11 09:36:01 -04:00
if ( elem - > is_index_dot ( ) )
2015-08-13 10:06:18 -07:00
// NOTE(review): the position string concatenates the width and the lane number with no
// operator between them (width 4, lane 1 yields the digits 41 after the c* prefix).
// This looks unintended; confirm the expected element index before relying on index dots.
compute_index_dot ( stream , elem - > process ( " #name_acc " ) , " c* " + to_string ( row_simd_width ) + to_string ( a ) , elem - > process ( " #name_acc_value " ) , value , elem - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
else
2015-07-11 09:36:01 -04:00
compute_dot ( stream , elem - > process ( " #name_acc " ) , value , elem - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
}
2015-02-10 16:33:38 -05:00
} ) ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-07-11 09:36:01 -04:00
// Each work-item publishes its accumulator at #name_buf[lidy*ld + lidx].
for ( auto & expr : dots )
2015-08-11 11:50:49 -07:00
stream < < expr - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] = #name_acc; " ) < < std : : endl ;
2015-02-10 16:33:38 -05:00
// Emitted reduction loop: stride starts at local_size_0/2 and is halved each step; every step
// begins with LocalBarrier(backend) and combines entry lidx with entry lidx + stride.
stream < < " #pragma unroll " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " stride = " < < p_ . local_size_0 / 2 < < " ; stride >0; stride /=2) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-04-29 15:50:57 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx < stride) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-11 09:36:01 -04:00
// NOTE(review): #name_buf_value is referenced for index dots, but this function only emits a
// declaration for #name_buf. Confirm where #name_buf_value is declared.
for ( auto & e : dots )
if ( e - > is_index_dot ( ) )
2015-08-11 11:50:49 -07:00
compute_index_dot ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
, e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-02-10 16:33:38 -05:00
, e - > root_op ( ) ) ;
else
2015-08-11 11:50:49 -07:00
compute_dot ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " ) , e - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-11 11:50:49 -07:00
// Only the work-item with lidx == 0 writes the result held at #name_buf[lidy*ld].
stream < < " if (lidx == 0 && r < M) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-02-10 23:01:16 -05:00
if ( p_ . num_groups_0 = = 1 )
{
// Single group along dimension 0: the reduced value is final, so the whole expression is
// evaluated and written here, once per vector lane s at position (r + s).
// NOTE(review): no r + s < M guard is emitted for the individual lanes in this branch.
std : : map < std : : string , std : : string > accessors ;
2015-08-13 10:06:18 -07:00
for ( int s = 0 ; s < col_simd_width ; + + s )
{
accessors [ " gemv " ] = " #name_buf[lidy* " + local_size_0_ld_str + " ] " ;
if ( col_simd_width > 1 )
accessors [ " gemv " ] = access_vector_type ( accessors [ " gemv " ] , s ) ;
accessors [ " array1 " ] = " #pointer[(r + " + to_string ( s ) + " )*#stride] " ;
evaluate ( stream , PARENT_NODE_TYPE , accessors , expressions , mappings ) ;
}
2015-02-10 23:01:16 -05:00
}
else
{
// Several groups along dimension 0: store the partial result at offset r + M*GroupIdx0 of the
// temporary buffer. With a vector width above 1, a vector store is emitted when M - r exceeds
// the width, and guarded per-lane scalar stores otherwise.
2015-07-11 09:36:01 -04:00
for ( mapped_dot const * e : dots )
2015-02-10 16:33:38 -05:00
{
2015-08-13 10:06:18 -07:00
if ( col_simd_width > 1 )
stream < < " if(M - r > " < < col_simd_width < < " ){ " < < std : : endl ;
2015-07-11 09:36:01 -04:00
// NOTE(review): the index-dot store below uses the type "uint" for #name_temp_value, while the
// argument list above declares #name_temp_value with the numeric type. Confirm which is intended.
if ( e - > is_index_dot ( ) )
2015-08-13 10:06:18 -07:00
stream < < e - > process ( vstore ( col_simd_width , " uint " , " #name_buf_value[lidy* " + local_size_0_ld_str + " ] " , " 0 " , " #name_temp_value + r + M* " + GroupIdx0 ( backend ) . get ( ) , backend ) ) < < " ; " < < std : : endl ;
stream < < e - > process ( vstore ( col_simd_width , " #scalartype " , " #name_buf[lidy* " + local_size_0_ld_str + " ] " , " 0 " , " #name_temp + r + M* " + GroupIdx0 ( backend ) . get ( ) , backend ) ) < < " ; " < < std : : endl ;
if ( col_simd_width > 1 )
{
stream < < " } " < < std : : endl ;
stream < < " else{ " < < std : : endl ;
stream . inc_tab ( ) ;
for ( int s = 0 ; s < col_simd_width ; + + s ) {
if ( e - > is_index_dot ( ) )
stream < < " if(r + " < < s < < " < M) " < < e - > process ( " #name_temp_value[r + " + to_string ( s ) + " + M* " + GroupIdx0 ( backend ) . get ( ) + " ] = " + access_vector_type ( " #name_buf_value[lidy* " + local_size_0_ld_str + " ] " , s ) ) < < " ; " < < std : : endl ;
stream < < " if(r + " < < s < < " < M) " < < e - > process ( " #name_temp[r + " + to_string ( s ) + " + M* " + GroupIdx0 ( backend ) . get ( ) + " ] = " + access_vector_type ( " #name_buf[lidy* " + local_size_0_ld_str + " ] " , s ) ) < < " ; " < < std : : endl ;
}
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
}
2015-01-12 13:20:53 -05:00
}
2015-02-10 23:01:16 -05:00
}
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-13 10:06:18 -07:00
// std::cout << stream.str() << std::endl;
2015-02-10 23:01:16 -05:00
if ( p_ . num_groups_0 > 1 )
{
2015-02-10 16:33:38 -05:00
/////////////////////////////////////////
////////////// Kernel 2
////////////////////////////////////////
// Second kernel: same argument list as the first. For each r it reads #name_temp[r + M*c]
// for c in [lidx, num_groups_0) with step LocalSize0, reduces those values through #name_buf
// and lets evaluate() write the final result. It works on scalars (no vector width).
2015-04-29 15:50:57 -04:00
if ( backend = = driver : : OPENCL )
stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ;
stream < < KernelPrefix ( backend ) < < " void " < < name [ 1 ] < < " ( " < < arguments < < generate_arguments ( " #scalartype " , device , mappings , expressions ) < < " ) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
process ( stream , PARENT_NODE_TYPE ,
{ { " array0 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " array1 " , " #pointer += #start; " } ,
2015-08-10 22:45:48 -07:00
{ " array2 " , " #pointer += #start; " } } , expressions , mappings ) ;
2015-02-10 16:33:38 -05:00
2015-07-11 09:36:01 -04:00
for ( const auto & e : dots )
2015-04-29 15:50:57 -04:00
stream < < e - > process ( Local ( backend ) . get ( ) + " #scalartype #name_buf[ " + to_string ( p_ . local_size_1 * local_size_0_ld ) + " ]; " ) < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " for( " < < _size_t < < " r = " < < GlobalIdx1 ( backend ) < < " ; r < (M + " < < p_ . local_size_1 - 1 < < " )/ " < < p_ . local_size_1 < < " * " < < p_ . local_size_1 < < " ; r += " < < GlobalSize1 ( backend ) < < " ){ " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < _size_t < < " lidx = " < < LocalIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < _size_t < < " lidy = " < < LocalIdx1 ( backend ) < < " ; " < < std : : endl ;
2015-02-10 16:33:38 -05:00
2015-07-11 09:36:01 -04:00
for ( const auto & e : dots )
2015-05-13 02:20:44 -04:00
stream < < e - > process ( " #scalartype #name_acc = " + neutral_element ( ( e ) - > root_op ( ) , backend , " #scalartype " ) + " ; " ) < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " if (r < M) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < " for( " < < _size_t < < " c = lidx; c < " < < p_ . num_groups_0 < < " ; c += " < < LocalSize0 ( backend ) < < " ){ " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream . inc_tab ( ) ;
2015-07-11 09:36:01 -04:00
// Every dot goes through compute_dot here, index dots included.
for ( mapped_dot * e : dots )
compute_dot ( stream , e - > process ( " #name_acc " ) , e - > process ( " #name_temp[r + M*c] " ) , e - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-07-11 09:36:01 -04:00
for ( auto & expr : dots )
2015-08-11 11:50:49 -07:00
stream < < expr - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] = #name_acc; " ) < < std : : endl ;
2015-01-12 13:20:53 -05:00
// Same halving-stride reduction over #name_buf as in the first kernel.
stream < < " #pragma unroll " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " stride = " < < p_ . local_size_0 / 2 < < " ; stride >0; stride /=2) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-04-29 15:50:57 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx < stride) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-11 09:36:01 -04:00
for ( auto & e : dots )
if ( e - > is_index_dot ( ) )
2015-08-11 11:50:49 -07:00
compute_index_dot ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
, e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-02-10 16:33:38 -05:00
, e - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
else
2015-08-11 11:50:49 -07:00
compute_dot ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " ) , e - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-11 11:50:49 -07:00
// Unlike the first kernel, no line break is emitted between this condition and its opening brace.
stream < < " if (lidx == 0 && r < M) " ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-02-10 16:33:38 -05:00
2015-01-12 13:20:53 -05:00
std : : map < std : : string , std : : string > accessors ;
2015-08-11 11:50:49 -07:00
accessors [ " gemv " ] = " #name_buf[lidy* " + local_size_0_ld_str + " ] " ;
2015-01-17 15:47:52 -05:00
accessors [ " array1 " ] = " #pointer[r*#stride] " ;
2015-02-01 22:28:49 -05:00
evaluate ( stream , PARENT_NODE_TYPE , accessors , expressions , mappings ) ;
2015-02-10 16:33:38 -05:00
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-02-10 23:01:16 -05:00
}
2015-02-10 16:33:38 -05:00
2015-01-12 13:20:53 -05:00
return stream . str ( ) ;
}
2015-07-11 09:36:01 -04:00
// Constructs a gemv template from its parameters, the reduction direction (stored in
// dot_type_) and the binding policy (forwarded to base_impl together with the parameters).
gemv::gemv(gemv::parameters_type const & parameters,
           gemv::dot_type rtype,
           binding_policy_t binding_policy)
  : base_impl<gemv, gemv_parameters>(parameters, binding_policy)
  , dot_type_(rtype)
{
}
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
std : : vector < int_t > gemv : : input_sizes ( expressions_tuple const & expressions ) const
2015-01-12 13:20:53 -05:00
{
2015-02-01 22:28:49 -05:00
array_expression const & first_expression = * expressions . data ( ) . front ( ) ;
2015-07-11 09:36:01 -04:00
std : : vector < std : : size_t > idx = filter_nodes ( & is_dot , first_expression , false ) ;
2015-01-16 07:31:39 -05:00
std : : pair < int_t , int_t > MN = matrix_size ( lhs_most ( first_expression . tree ( ) , idx [ 0 ] ) ) ;
2015-07-11 09:36:01 -04:00
if ( dot_type_ = = REDUCE_COLUMNS )
2015-01-12 13:20:53 -05:00
std : : swap ( MN . first , MN . second ) ;
2015-08-06 12:05:12 -07:00
return { MN . first , MN . second } ;
2015-01-12 13:20:53 -05:00
}
2015-08-05 11:42:08 -07:00
void gemv : : enqueue ( driver : : CommandQueue & queue , driver : : Program const & program , std : : string const & suffix , base & fallback , controller < expressions_tuple > const & controller )
2015-01-12 13:20:53 -05:00
{
2015-02-05 04:42:57 -05:00
expressions_tuple const & expressions = controller . x ( ) ;
2015-04-29 15:50:57 -04:00
driver : : Context const & context = expressions . context ( ) ;
2015-02-10 16:33:38 -05:00
2015-02-01 22:28:49 -05:00
std : : vector < int_t > MN = input_sizes ( expressions ) ;
2015-07-11 09:36:01 -04:00
std : : vector < array_expression : : node const * > dots ;
2015-02-10 16:33:38 -05:00
for ( const auto & e : expressions . data ( ) )
{
2015-07-11 09:36:01 -04:00
std : : vector < size_t > dots_idx = filter_nodes ( & is_dot , * e , false ) ;
for ( auto & r : dots_idx )
dots . push_back ( & ( e ) - > tree ( ) [ r ] ) ;
2015-02-10 16:33:38 -05:00
}
2015-04-29 15:50:57 -04:00
//Fallback
2015-08-13 10:06:18 -07:00
if ( p_ . simd_width > 1 & & requires_fallback ( expressions ) )
2015-04-29 15:50:57 -04:00
{
fallback . enqueue ( queue , program , " fallback " , fallback , controller ) ;
return ;
}
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
//Kernel
std : : vector < driver : : Buffer > tmp ;
std : : vector < driver : : Buffer > tmpidx ;
2015-02-10 16:33:38 -05:00
unsigned int dtype_size = size_of ( lhs_most ( expressions . data ( ) . front ( ) - > tree ( ) , expressions . data ( ) . front ( ) - > root ( ) ) . lhs . dtype ) ;
2015-02-10 23:01:16 -05:00
2015-08-05 11:42:08 -07:00
std : : string name [ 2 ] = { " prod " , " reduce " } ;
name [ 0 ] + = suffix ;
name [ 1 ] + = suffix ;
2015-02-10 23:01:16 -05:00
unsigned int nk = ( p_ . num_groups_0 = = 1 ) ? 1 : 2 ;
2015-04-29 15:50:57 -04:00
std : : vector < driver : : Kernel > kernels ;
2015-02-10 23:01:16 -05:00
for ( unsigned int k = 0 ; k < nk ; + + k )
2015-08-05 11:42:08 -07:00
kernels . push_back ( driver : : Kernel ( program , name [ k ] . c_str ( ) ) ) ;
2015-02-10 23:01:16 -05:00
for ( unsigned int k = 0 ; k < nk ; + + k )
2015-02-10 16:33:38 -05:00
{
2015-04-29 15:50:57 -04:00
driver : : Kernel & kernel = kernels [ k ] ;
2015-02-10 16:33:38 -05:00
unsigned int n_arg = 0 ;
int_t M = MN [ 0 ] ;
int_t N = MN [ 1 ] ;
2015-04-29 15:50:57 -04:00
kernel . setSizeArg ( n_arg + + , M ) ;
kernel . setSizeArg ( n_arg + + , N ) ;
2015-02-10 16:33:38 -05:00
//Temporary buffers
unsigned int i = 0 ;
unsigned int j = 0 ;
2015-07-11 09:36:01 -04:00
for ( auto const & r : dots )
2015-02-10 16:33:38 -05:00
{
2015-07-11 09:36:01 -04:00
if ( is_index_dot ( r - > op ) )
2015-02-10 16:33:38 -05:00
{
if ( tmpidx . size ( ) < = j )
2015-04-29 15:50:57 -04:00
tmpidx . push_back ( driver : : Buffer ( context , p_ . num_groups_0 * M * 4 ) ) ;
2015-02-10 23:01:16 -05:00
kernel . setArg ( n_arg + + , tmpidx [ j ] ) ;
2015-02-10 16:33:38 -05:00
j + + ;
}
if ( tmp . size ( ) < = i )
2015-04-29 15:50:57 -04:00
tmp . push_back ( driver : : Buffer ( context , p_ . num_groups_0 * M * dtype_size ) ) ;
2015-02-10 23:01:16 -05:00
kernel . setArg ( n_arg + + , tmp [ i ] ) ;
2015-02-10 16:33:38 -05:00
i + + ;
}
2015-08-06 16:14:33 -07:00
set_arguments ( expressions , kernel , n_arg , binding_policy_ ) ;
2015-02-10 16:33:38 -05:00
}
2015-01-12 13:20:53 -05:00
2015-02-10 23:01:16 -05:00
//NDRange
2015-04-29 15:50:57 -04:00
driver : : NDRange global [ 2 ] = { driver : : NDRange ( p_ . local_size_0 * p_ . num_groups_0 , p_ . local_size_1 * p_ . num_groups_1 ) , driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 * p_ . num_groups_1 ) } ;
driver : : NDRange local [ 2 ] = { driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 ) , driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 ) } ;
2015-02-10 23:01:16 -05:00
for ( unsigned int i = 0 ; i < nk ; + + i )
2015-06-23 09:38:34 -07:00
controller . execution_options ( ) . enqueue ( program . context ( ) , kernels [ i ] , global [ i ] , local [ i ] ) ;
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
// gemv variant that reduces rows, built from an existing parameter object.
gemv_n::gemv_n(gemv_parameters const & parameters,
               binding_policy_t binding_policy)
  : gemv(parameters, REDUCE_ROWS, binding_policy)
{
}
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
// gemv variant that reduces rows, built from individual tuning values; they are packed into
// a gemv_parameters object in the order (simd, ls1, ls2, ng1, ng2, fetch).
gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
               unsigned int ng1, unsigned int ng2,
               fetching_policy_type fetch, binding_policy_t bind)
  : gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
{
}
2015-07-11 09:36:01 -04:00
// gemv variant that reduces columns, built from an existing parameter object.
gemv_t::gemv_t(gemv::parameters_type const & parameters,
               binding_policy_t binding_policy)
  : gemv(parameters, REDUCE_COLUMNS, binding_policy)
{
}
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
// gemv variant that reduces columns, built from individual tuning values; they are packed
// into a gemv_parameters object in the order (simd, ls1, ls2, ng1, ng2, fetch).
gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
               unsigned int ng1, unsigned int ng2,
               fetching_policy_type fetch, binding_policy_t bind)
  : gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
{
}
}
2015-07-11 09:36:01 -04:00
}