2015-07-25 21:00:18 -07:00
# include <cstring>
2015-01-12 13:20:53 -05:00
# include <iostream>
2015-08-04 20:56:05 -07:00
# include "isaac/kernels/stream.h"
# include "isaac/kernels/keywords.h"
2015-12-12 18:32:06 -05:00
# include "isaac/kernels/templates/reduce_2d.h"
2015-08-06 12:05:12 -07:00
2015-08-06 16:14:33 -07:00
# include "tools/arguments.hpp"
2015-08-06 12:05:12 -07:00
# include "tools/loop.hpp"
2015-08-06 16:14:33 -07:00
# include "tools/reductions.hpp"
# include "tools/vector_types.hpp"
2015-08-06 19:34:26 -07:00
# include <string>
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
namespace templates
{
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
// Tuning parameters of the reduce_2d template.
//   simd      - vector width forwarded to the base parameters
//   ls0, ls1  - work-group sizes along dimensions 0 and 1
//   ng0, ng1  - number of work-groups along dimensions 0 and 1
//   policy    - fetching policy used by the generated loops
// The fourth argument given to the base parameters is fixed to 1
// (presumably a third local dimension - TODO confirm against the base type).
reduce_2d_parameters::reduce_2d_parameters(unsigned int simd,
                                           unsigned int ls0,
                                           unsigned int ls1,
                                           unsigned int ng0,
                                           unsigned int ng1,
                                           fetching_policy_type policy)
  : base::parameters_type(simd, ls0, ls1, 1)
  , num_groups_0(ng0)
  , num_groups_1(ng1)
  , fetch_policy(policy)
{
}
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
int reduce_2d : : is_invalid_impl ( driver : : Device const & , math_expression const & ) const
2015-01-12 13:20:53 -05:00
{
if ( p_ . fetch_policy = = FETCH_FROM_LOCAL )
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE ;
return TEMPLATE_VALID ;
}
2015-12-12 18:32:06 -05:00
unsigned int reduce_2d : : lmem_usage ( const math_expression & ) const
2015-01-12 13:20:53 -05:00
{
2015-06-28 17:53:16 -07:00
return ( p_ . local_size_0 + 1 ) * p_ . local_size_1 ;
2015-01-12 13:20:53 -05:00
}
2015-12-12 18:32:06 -05:00
std : : string reduce_2d : : generate_impl ( std : : string const & suffix , math_expression const & expression , driver : : Device const & device , mapping_type const & mapping ) const
2015-01-12 13:20:53 -05:00
{
2015-08-06 20:20:08 -07:00
using tools : : to_string ;
2015-02-10 23:01:16 -05:00
2015-06-30 17:55:57 -04:00
2015-12-12 18:32:06 -05:00
std : : vector < mapped_reduce_2d * > reduce_1ds ;
std : : vector < size_t > idx = filter_nodes ( & is_reduce_1d , expression , expression . root ( ) , false ) ;
2015-09-30 15:31:41 -04:00
for ( auto & elem : idx )
2015-12-12 18:32:06 -05:00
reduce_1ds . push_back ( ( mapped_reduce_2d * ) ( mapping . at ( mapping_key ( elem , PARENT_NODE_TYPE ) ) . get ( ) ) ) ;
2015-01-12 13:20:53 -05:00
kernel_generation_stream stream ;
2015-04-29 15:50:57 -04:00
driver : : backend_type backend = device . backend ( ) ;
std : : string _size_t = size_type ( device ) ;
2015-01-12 13:20:53 -05:00
2015-08-05 11:42:08 -07:00
std : : string name [ 2 ] = { " prod " , " reduce " } ;
name [ 0 ] + = suffix ;
name [ 1 ] + = suffix ;
2015-01-27 02:41:27 -05:00
2015-12-05 19:14:09 -05:00
auto unroll_tmp = [ & ] ( )
2015-02-10 16:33:38 -05:00
{
2015-12-05 19:14:09 -05:00
unsigned int offset = 0 ;
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds )
2015-12-05 19:14:09 -05:00
{
numeric_type dtype = lhs_most ( e - > math_expression ( ) . tree ( ) , e - > math_expression ( ) . root ( ) ) . lhs . dtype ;
std : : string sdtype = to_string ( dtype ) ;
2015-12-12 18:32:06 -05:00
if ( e - > is_index_reduction ( ) )
2015-12-05 19:14:09 -05:00
{
stream < < e - > process ( " uint* #name_temp = (uint*)(tmp + " + tools : : to_string ( offset ) + " *M); " ) ;
offset + = 4 * p_ . num_groups_0 ;
stream < < e - > process ( sdtype + " * #name_temp_value = ( " + sdtype + " *)(tmp + " + tools : : to_string ( offset ) + " *M); " ) ;
offset + = size_of ( dtype ) * p_ . num_groups_0 ;
}
else {
stream < < e - > process ( sdtype + " * #name_temp = ( " + sdtype + " *)(tmp + " + tools : : to_string ( offset ) + " *M); " ) ;
offset + = size_of ( dtype ) * p_ . num_groups_0 ;
}
}
} ;
2015-02-10 16:33:38 -05:00
2015-12-12 18:32:06 -05:00
int col_simd_width = ( reduce_1d_type_ = = REDUCE_COLUMNS ) ? 1 : p_ . simd_width ;
2015-05-13 02:20:44 -04:00
switch ( backend )
{
2015-08-25 12:41:21 -04:00
case driver : : CUDA :
stream < < " #include \" helper_math.h \" " < < std : : endl ; break ;
case driver : : OPENCL :
stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ; break ;
2015-05-13 02:20:44 -04:00
}
2015-04-29 15:50:57 -04:00
2015-12-12 13:30:16 -05:00
stream < < KernelPrefix ( backend ) < < " void " < < name [ 0 ] < < " ( " < < _size_t < < " M, " < < _size_t < < " N, " < < Global ( backend ) < < " char* tmp, " < < generate_arguments ( " #scalartype " , device , mapping , expression ) < < " ) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-12-05 19:14:09 -05:00
unroll_tmp ( ) ;
2015-01-12 13:20:53 -05:00
process ( stream , PARENT_NODE_TYPE ,
2015-11-19 12:37:18 -05:00
{ { " array1 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " arrayn " , " #pointer += #start; " } ,
{ " arraynn " , " #pointer += #start; " } } , expression , mapping ) ;
2015-01-12 13:20:53 -05:00
2015-06-28 17:53:16 -07:00
unsigned int local_size_0_ld = p_ . local_size_0 ;
2015-02-10 23:01:16 -05:00
std : : string local_size_0_ld_str = to_string ( local_size_0_ld ) ;
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds )
2015-08-13 10:06:18 -07:00
stream < < e - > process ( Local ( backend ) . get ( ) + " " + append_width ( " #scalartype " , col_simd_width ) + " #name_buf[ " + to_string ( p_ . local_size_1 * local_size_0_ld ) + " ]; " ) < < std : : endl ;
2015-04-29 15:50:57 -04:00
2015-08-13 10:06:18 -07:00
stream < < " for( " < < _size_t < < " r = " < < GlobalIdx1 ( backend ) < < " * " < < col_simd_width < < " ; r < (M + " < < p_ . local_size_1 - 1 < < " )/ " < < p_ . local_size_1 < < " * " < < p_ . local_size_1 * col_simd_width < < " ; r += " < < GlobalSize1 ( backend ) < < " * " < < col_simd_width < < " ) " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " { " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < " " < < _size_t < < " lidx = " < < LocalIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < " " < < _size_t < < " lidy = " < < LocalIdx1 ( backend ) < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds ) {
2015-08-13 10:06:18 -07:00
std : : string data_type = append_width ( " #scalartype " , col_simd_width ) ;
2015-08-30 02:23:40 -04:00
stream < < e - > process ( data_type + " #name_acc = " + InitPrefix ( backend , data_type ) . get ( ) + " ( " + neutral_element ( ( e ) - > root_op ( ) , backend , " #scalartype " ) + " ); " ) < < std : : endl ;
2015-08-13 10:06:18 -07:00
}
2015-01-12 13:20:53 -05:00
stream < < " if (r < M) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-12-12 18:32:06 -05:00
element_wise_loop_1D ( stream , p_ . fetch_policy , ( reduce_1d_type_ = = REDUCE_COLUMNS ) ? p_ . simd_width : 1 , " c " , " N " , GlobalIdx0 ( backend ) . get ( ) , GlobalSize0 ( backend ) . get ( ) , device , [ & ] ( unsigned int row_simd_width )
2015-01-12 13:20:53 -05:00
{
2015-09-30 15:31:41 -04:00
std : : set < std : : string > already_fetched ;
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds )
2015-01-12 13:20:53 -05:00
{
2015-02-10 16:33:38 -05:00
std : : map < std : : string , std : : string > accessors ;
2015-12-12 18:32:06 -05:00
if ( reduce_1d_type_ = = REDUCE_COLUMNS )
2015-01-12 13:20:53 -05:00
{
2015-08-13 10:06:18 -07:00
std : : string data_type = append_width ( " #scalartype " , row_simd_width ) ;
2015-11-19 12:37:18 -05:00
accessors [ " arraynn " ] = data_type + " #namereg = " + vload ( row_simd_width , " #scalartype " , " c*#stride " , " #pointer + r*#ld " , " 1 " , backend , false ) + " ; " ;
2015-09-30 15:31:41 -04:00
accessors [ " repeat " ] = data_type + " #namereg = " + vload ( row_simd_width , " #scalartype " , " (c%#sub0)*#stride " , " #pointer + (r%#sub1)*#stride " , " 1 " , backend , false ) + " ; " ;
2015-02-10 16:33:38 -05:00
}
else
{
2015-08-13 10:06:18 -07:00
std : : string data_type = append_width ( " #scalartype " , col_simd_width ) ;
2015-11-19 12:37:18 -05:00
accessors [ " arraynn " ] = data_type + " #namereg = " + vload ( col_simd_width , " #scalartype " , " 0 " , " #pointer + r*#stride + c*#ld " , " 1 " , backend , false ) + " ; " ;
2015-09-30 15:31:41 -04:00
accessors [ " repeat " ] = " #scalartype #namereg = $VALUE{(r%#sub0)*#stride, (c%#sub1)*#stride}; " ;
2015-01-12 13:20:53 -05:00
}
2015-09-30 15:31:41 -04:00
e - > process_recursive ( stream , PARENT_NODE_TYPE , accessors , already_fetched ) ;
2015-02-10 16:33:38 -05:00
}
2015-01-12 13:20:53 -05:00
2015-02-10 16:33:38 -05:00
//Update accumulators
2015-08-13 10:06:18 -07:00
std : : vector < std : : string > str ( row_simd_width ) ;
if ( row_simd_width = = 1 )
2015-02-10 16:33:38 -05:00
str [ 0 ] = " #namereg " ;
else
2015-08-13 10:06:18 -07:00
for ( unsigned int a = 0 ; a < row_simd_width ; + + a )
2015-04-29 15:50:57 -04:00
str [ a ] = access_vector_type ( " #namereg " , a ) ;
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
for ( auto & elem : reduce_1ds )
2015-08-13 10:06:18 -07:00
for ( unsigned int a = 0 ; a < row_simd_width ; + + a )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
std : : string value = elem - > evaluate_recursive ( LHS_NODE_TYPE , { { " arraynn " , str [ a ] } , { " repeat " , str [ a ] } , { " array1 " , " #namereg " } } ) ;
2015-12-12 18:32:06 -05:00
if ( elem - > is_index_reduction ( ) )
compute_index_reduce_1d ( stream , elem - > process ( " #name_acc " ) , " c* " + to_string ( row_simd_width ) + to_string ( a ) , elem - > process ( " #name_acc_value " ) , value , elem - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
else
2015-12-12 18:32:06 -05:00
compute_reduce_1d ( stream , elem - > process ( " #name_acc " ) , value , elem - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
}
2015-02-10 16:33:38 -05:00
} ) ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-12-12 18:32:06 -05:00
for ( auto & expr : reduce_1ds )
2015-08-11 11:50:49 -07:00
stream < < expr - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] = #name_acc; " ) < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " #pragma unroll " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " stride = " < < p_ . local_size_0 / 2 < < " ; stride >0; stride /=2) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-04-29 15:50:57 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx < stride) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-12-12 18:32:06 -05:00
for ( auto & e : reduce_1ds )
if ( e - > is_index_reduction ( ) )
compute_index_reduce_1d ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-08-11 11:50:49 -07:00
, e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-02-10 16:33:38 -05:00
, e - > root_op ( ) ) ;
else
2015-12-12 18:32:06 -05:00
compute_reduce_1d ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " ) , e - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx == 0 && r < M) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-02-10 23:01:16 -05:00
if ( p_ . num_groups_0 = = 1 )
{
std : : map < std : : string , std : : string > accessors ;
2015-08-13 10:06:18 -07:00
for ( int s = 0 ; s < col_simd_width ; + + s )
{
2015-12-12 18:32:06 -05:00
accessors [ " reduce_2d " ] = " #name_buf[lidy* " + local_size_0_ld_str + " ] " ;
2015-08-13 10:06:18 -07:00
if ( col_simd_width > 1 )
2015-12-12 18:32:06 -05:00
accessors [ " reduce_2d " ] = access_vector_type ( accessors [ " reduce_2d " ] , s ) ;
2015-11-19 12:37:18 -05:00
accessors [ " arrayn " ] = " #pointer[(r + " + to_string ( s ) + " )*#stride] " ;
2015-11-25 18:42:25 -05:00
accessors [ " array1n " ] = " #pointer[(r + " + to_string ( s ) + " )*#stride] " ;
accessors [ " arrayn1 " ] = " #pointer[(r + " + to_string ( s ) + " )*#stride] " ;
2015-09-30 15:31:41 -04:00
stream < < evaluate ( PARENT_NODE_TYPE , accessors , expression , expression . root ( ) , mapping ) < < " ; " < < std : : endl ;
2015-08-13 10:06:18 -07:00
}
2015-02-10 23:01:16 -05:00
}
else
{
2015-12-12 18:32:06 -05:00
for ( mapped_reduce const * e : reduce_1ds )
2015-02-10 16:33:38 -05:00
{
2015-08-13 10:06:18 -07:00
if ( col_simd_width > 1 )
stream < < " if(M - r > " < < col_simd_width < < " ){ " < < std : : endl ;
2015-12-12 18:32:06 -05:00
if ( e - > is_index_reduction ( ) )
2015-09-30 15:31:41 -04:00
stream < < e - > process ( vstore ( col_simd_width , " uint " , " #name_buf_value[lidy* " + local_size_0_ld_str + " ] " , " 0 " , " #name_temp_value + r + M* " + GroupIdx0 ( backend ) . get ( ) , " 1 " , backend , false ) ) < < " ; " < < std : : endl ;
stream < < e - > process ( vstore ( col_simd_width , " #scalartype " , " #name_buf[lidy* " + local_size_0_ld_str + " ] " , " 0 " , " #name_temp + r + M* " + GroupIdx0 ( backend ) . get ( ) , " 1 " , backend , false ) ) < < " ; " < < std : : endl ;
2015-08-13 10:06:18 -07:00
if ( col_simd_width > 1 )
{
stream < < " } " < < std : : endl ;
stream < < " else{ " < < std : : endl ;
stream . inc_tab ( ) ;
for ( int s = 0 ; s < col_simd_width ; + + s ) {
2015-12-12 18:32:06 -05:00
if ( e - > is_index_reduction ( ) )
2015-08-13 10:06:18 -07:00
stream < < " if(r + " < < s < < " < M) " < < e - > process ( " #name_temp_value[r + " + to_string ( s ) + " + M* " + GroupIdx0 ( backend ) . get ( ) + " ] = " + access_vector_type ( " #name_buf_value[lidy* " + local_size_0_ld_str + " ] " , s ) ) < < " ; " < < std : : endl ;
stream < < " if(r + " < < s < < " < M) " < < e - > process ( " #name_temp[r + " + to_string ( s ) + " + M* " + GroupIdx0 ( backend ) . get ( ) + " ] = " + access_vector_type ( " #name_buf[lidy* " + local_size_0_ld_str + " ] " , s ) ) < < " ; " < < std : : endl ;
}
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
}
2015-01-12 13:20:53 -05:00
}
2015-02-10 23:01:16 -05:00
}
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-13 10:06:18 -07:00
2015-02-10 23:01:16 -05:00
if ( p_ . num_groups_0 > 1 )
{
2015-02-10 16:33:38 -05:00
/////////////////////////////////////////
////////////// Kernel 2
////////////////////////////////////////
2015-04-29 15:50:57 -04:00
if ( backend = = driver : : OPENCL )
stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ;
2015-12-12 13:30:16 -05:00
stream < < KernelPrefix ( backend ) < < " void " < < name [ 1 ] < < " ( " < < _size_t < < " M, " < < _size_t < < " N , " < < Global ( backend ) < < " char* tmp, " < < generate_arguments ( " #scalartype " , device , mapping , expression ) < < " ) " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-12-05 19:14:09 -05:00
unroll_tmp ( ) ;
2015-02-10 16:33:38 -05:00
process ( stream , PARENT_NODE_TYPE ,
2015-11-19 12:37:18 -05:00
{ { " array1 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " arrayn " , " #pointer += #start; " } ,
2015-11-25 18:42:25 -05:00
{ " array1n " , " #pointer += #start; " } ,
{ " arrayn1 " , " #pointer += #start; " } ,
2015-11-19 12:37:18 -05:00
{ " arraynn " , " #pointer += #start; " } } , expression , mapping ) ;
2015-02-10 16:33:38 -05:00
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds )
2015-04-29 15:50:57 -04:00
stream < < e - > process ( Local ( backend ) . get ( ) + " #scalartype #name_buf[ " + to_string ( p_ . local_size_1 * local_size_0_ld ) + " ]; " ) < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " for( " < < _size_t < < " r = " < < GlobalIdx1 ( backend ) < < " ; r < (M + " < < p_ . local_size_1 - 1 < < " )/ " < < p_ . local_size_1 < < " * " < < p_ . local_size_1 < < " ; r += " < < GlobalSize1 ( backend ) < < " ){ " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < _size_t < < " lidx = " < < LocalIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < _size_t < < " lidy = " < < LocalIdx1 ( backend ) < < " ; " < < std : : endl ;
2015-02-10 16:33:38 -05:00
2015-12-12 18:32:06 -05:00
for ( const auto & e : reduce_1ds )
2015-05-13 02:20:44 -04:00
stream < < e - > process ( " #scalartype #name_acc = " + neutral_element ( ( e ) - > root_op ( ) , backend , " #scalartype " ) + " ; " ) < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream < < " if (r < M) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-08-11 11:50:49 -07:00
stream < < " for( " < < _size_t < < " c = lidx; c < " < < p_ . num_groups_0 < < " ; c += " < < LocalSize0 ( backend ) < < " ){ " < < std : : endl ;
2015-02-10 16:33:38 -05:00
stream . inc_tab ( ) ;
2015-12-12 18:32:06 -05:00
for ( mapped_reduce * e : reduce_1ds )
compute_reduce_1d ( stream , e - > process ( " #name_acc " ) , e - > process ( " #name_temp[r + M*c] " ) , e - > root_op ( ) ) ;
2015-02-10 16:33:38 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-12-12 18:32:06 -05:00
for ( auto & expr : reduce_1ds )
2015-08-11 11:50:49 -07:00
stream < < expr - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] = #name_acc; " ) < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " #pragma unroll " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " stride = " < < p_ . local_size_0 / 2 < < " ; stride >0; stride /=2) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-04-29 15:50:57 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx < stride) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-12-12 18:32:06 -05:00
for ( auto & e : reduce_1ds )
if ( e - > is_index_reduction ( ) )
compute_index_reduce_1d ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-08-11 11:50:49 -07:00
, e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf_value[lidy* " + local_size_0_ld_str + " + lidx + stride] " )
2015-02-10 16:33:38 -05:00
, e - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
else
2015-12-12 18:32:06 -05:00
compute_reduce_1d ( stream , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx] " ) , e - > process ( " #name_buf[lidy* " + local_size_0_ld_str + " + lidx + stride] " ) , e - > root_op ( ) ) ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-08-11 11:50:49 -07:00
stream < < " if (lidx == 0 && r < M) " ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-02-10 16:33:38 -05:00
2015-01-12 13:20:53 -05:00
std : : map < std : : string , std : : string > accessors ;
2015-12-12 18:32:06 -05:00
accessors [ " reduce_2d " ] = " #name_buf[lidy* " + local_size_0_ld_str + " ] " ;
2015-11-19 12:37:18 -05:00
accessors [ " arrayn " ] = " #pointer[r*#stride] " ;
2015-11-25 18:42:25 -05:00
accessors [ " array1n " ] = " #pointer[r*#stride] " ;
accessors [ " arrayn1 " ] = " #pointer[r*#stride] " ;
2015-09-30 15:31:41 -04:00
stream < < evaluate ( PARENT_NODE_TYPE , accessors , expression , expression . root ( ) , mapping ) < < " ; " < < std : : endl ;
2015-02-10 16:33:38 -05:00
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-02-10 23:01:16 -05:00
}
2015-02-10 16:33:38 -05:00
2015-11-19 12:37:18 -05:00
// std::cout << stream.str() << std::endl;
2015-01-12 13:20:53 -05:00
return stream . str ( ) ;
}
2015-12-12 18:32:06 -05:00
// Builds a reduce_2d template from its tuning parameters, the reduction
// direction and the argument-binding policy. The parameters and binding policy
// go to base_impl; the direction is kept in reduce_1d_type_.
reduce_2d::reduce_2d(reduce_2d::parameters_type const & params,
                     reduce_2d::reduce_1d_type direction,
                     binding_policy_t policy)
  : base_impl<reduce_2d, reduce_2d_parameters>(params, policy)
  , reduce_1d_type_(direction)
{
}
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
std : : vector < int_t > reduce_2d : : input_sizes ( math_expression const & expression ) const
2015-01-12 13:20:53 -05:00
{
2015-12-12 18:32:06 -05:00
std : : vector < std : : size_t > idx = filter_nodes ( & is_reduce_1d , expression , expression . root ( ) , false ) ;
2015-09-30 15:31:41 -04:00
std : : pair < int_t , int_t > MN = matrix_size ( expression . tree ( ) , lhs_most ( expression . tree ( ) , idx [ 0 ] ) ) ;
2015-12-12 18:32:06 -05:00
if ( reduce_1d_type_ = = REDUCE_COLUMNS )
2015-01-12 13:20:53 -05:00
std : : swap ( MN . first , MN . second ) ;
2015-08-06 12:05:12 -07:00
return { MN . first , MN . second } ;
2015-01-12 13:20:53 -05:00
}
2015-12-12 18:32:06 -05:00
void reduce_2d : : enqueue ( driver : : CommandQueue & queue , driver : : Program const & program , std : : string const & suffix , base & fallback , execution_handler const & control )
2015-01-12 13:20:53 -05:00
{
2015-09-30 15:31:41 -04:00
math_expression const & expression = control . x ( ) ;
2015-02-10 16:33:38 -05:00
2015-09-30 15:31:41 -04:00
std : : vector < int_t > MN = input_sizes ( expression ) ;
2015-12-12 18:32:06 -05:00
std : : vector < math_expression : : node const * > reduce_1ds ;
std : : vector < size_t > reduce_1ds_idx = filter_nodes ( & is_reduce_1d , expression , expression . root ( ) , false ) ;
for ( size_t idx : reduce_1ds_idx )
reduce_1ds . push_back ( & expression . tree ( ) [ idx ] ) ;
2015-02-10 16:33:38 -05:00
2015-04-29 15:50:57 -04:00
//Fallback
2015-09-30 15:31:41 -04:00
if ( p_ . simd_width > 1 & & requires_fallback ( expression ) )
2015-04-29 15:50:57 -04:00
{
2015-09-30 15:31:41 -04:00
fallback . enqueue ( queue , program , " fallback " , fallback , control ) ;
2015-04-29 15:50:57 -04:00
return ;
}
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
//Kernel
2015-08-05 11:42:08 -07:00
std : : string name [ 2 ] = { " prod " , " reduce " } ;
name [ 0 ] + = suffix ;
name [ 1 ] + = suffix ;
2015-02-10 23:01:16 -05:00
unsigned int nk = ( p_ . num_groups_0 = = 1 ) ? 1 : 2 ;
2015-04-29 15:50:57 -04:00
std : : vector < driver : : Kernel > kernels ;
2015-02-10 23:01:16 -05:00
for ( unsigned int k = 0 ; k < nk ; + + k )
2015-08-05 11:42:08 -07:00
kernels . push_back ( driver : : Kernel ( program , name [ k ] . c_str ( ) ) ) ;
2015-02-10 23:01:16 -05:00
for ( unsigned int k = 0 ; k < nk ; + + k )
2015-02-10 16:33:38 -05:00
{
2015-04-29 15:50:57 -04:00
driver : : Kernel & kernel = kernels [ k ] ;
2015-02-10 16:33:38 -05:00
unsigned int n_arg = 0 ;
int_t M = MN [ 0 ] ;
int_t N = MN [ 1 ] ;
2015-04-29 15:50:57 -04:00
kernel . setSizeArg ( n_arg + + , M ) ;
kernel . setSizeArg ( n_arg + + , N ) ;
2015-11-27 18:43:46 -05:00
kernel . setArg ( n_arg + + , driver : : backend : : workspaces : : get ( queue ) ) ; //Temporary buffers
2015-09-30 15:31:41 -04:00
set_arguments ( expression , kernel , n_arg , binding_policy_ ) ;
2015-02-10 16:33:38 -05:00
}
2015-01-12 13:20:53 -05:00
2015-02-10 23:01:16 -05:00
//NDRange
2015-04-29 15:50:57 -04:00
driver : : NDRange global [ 2 ] = { driver : : NDRange ( p_ . local_size_0 * p_ . num_groups_0 , p_ . local_size_1 * p_ . num_groups_1 ) , driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 * p_ . num_groups_1 ) } ;
driver : : NDRange local [ 2 ] = { driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 ) , driver : : NDRange ( p_ . local_size_0 , p_ . local_size_1 ) } ;
2015-02-10 23:01:16 -05:00
for ( unsigned int i = 0 ; i < nk ; + + i )
2015-09-30 15:31:41 -04:00
control . execution_options ( ) . enqueue ( program . context ( ) , kernels [ i ] , global [ i ] , local [ i ] ) ;
2015-01-12 13:20:53 -05:00
}
2015-12-16 16:34:36 -05:00
// reduce_2d fixed to the REDUCE_ROWS direction.
reduce_2d_rows::reduce_2d_rows(reduce_2d_parameters const & params, binding_policy_t policy)
  : reduce_2d(params, REDUCE_ROWS, policy)
{
}

// Same as above, building the reduce_2d_parameters from individual values.
reduce_2d_rows::reduce_2d_rows(unsigned int simd_width,
                               unsigned int local_size_0, unsigned int local_size_1,
                               unsigned int num_groups_0, unsigned int num_groups_1,
                               fetching_policy_type fetch, binding_policy_t bind)
  : reduce_2d(reduce_2d_parameters(simd_width, local_size_0, local_size_1, num_groups_0, num_groups_1, fetch),
              REDUCE_ROWS, bind)
{
}
2015-01-12 13:20:53 -05:00
2015-12-16 16:34:36 -05:00
// reduce_2d fixed to the REDUCE_COLUMNS direction.
reduce_2d_cols::reduce_2d_cols(reduce_2d::parameters_type const & params, binding_policy_t policy)
  : reduce_2d(params, REDUCE_COLUMNS, policy)
{
}

// Same as above, building the reduce_2d_parameters from individual values.
reduce_2d_cols::reduce_2d_cols(unsigned int simd_width,
                               unsigned int local_size_0, unsigned int local_size_1,
                               unsigned int num_groups_0, unsigned int num_groups_1,
                               fetching_policy_type fetch, binding_policy_t bind)
  : reduce_2d(reduce_2d_parameters(simd_width, local_size_0, local_size_1, num_groups_0, num_groups_1, fetch),
              REDUCE_COLUMNS, bind)
{
}
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
}