2015-07-25 21:00:18 -07:00
# include <cstring>
# include <iostream>
2015-12-12 18:32:06 -05:00
# include "isaac/kernels/templates/elementwise_2d.h"
2015-04-29 15:50:57 -04:00
# include "isaac/symbolic/io.h"
2015-08-04 20:56:05 -07:00
# include "isaac/kernels/keywords.h"
2015-01-16 19:39:26 -05:00
2015-08-06 16:14:33 -07:00
# include "tools/arguments.hpp"
2015-08-06 12:05:12 -07:00
# include "tools/loop.hpp"
2015-08-06 16:14:33 -07:00
# include "tools/vector_types.hpp"
2015-08-06 12:05:12 -07:00
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
namespace templates
{
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
/**
 * Builds the tuning-parameter set of the elementwise_2d template.
 *
 * The SIMD width and the two local sizes are forwarded to
 * base::parameters_type (whose fourth constructor argument is fixed to 1 here);
 * the two group counts and the fetching policy are stored in this object.
 */
elementwise_2d_parameters::elementwise_2d_parameters(unsigned int _simd_width,
                                                     unsigned int _local_size_0, unsigned int _local_size_1,
                                                     unsigned int _num_groups_0, unsigned int _num_groups_1,
                                                     fetching_policy_type _fetching_policy)
  : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
    num_groups_0(_num_groups_0),
    num_groups_1(_num_groups_1),
    fetching_policy(_fetching_policy)
{
}
2015-01-12 13:20:53 -05:00
2015-12-19 02:55:24 -05:00
int elementwise_2d : : is_invalid_impl ( driver : : Device const & , expression_tree const & ) const
2015-01-12 13:20:53 -05:00
{
if ( p_ . simd_width > 1 )
return TEMPLATE_INVALID_SIMD_WIDTH ;
if ( p_ . fetching_policy = = FETCH_FROM_LOCAL )
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE ;
return TEMPLATE_VALID ;
}
2015-12-19 02:55:24 -05:00
std : : string elementwise_2d : : generate_impl ( std : : string const & suffix , expression_tree const & expressions , driver : : Device const & device , mapping_type const & mappings ) const
2015-01-12 13:20:53 -05:00
{
kernel_generation_stream stream ;
2015-04-29 15:50:57 -04:00
std : : string _size_t = size_type ( device ) ;
2015-01-12 13:20:53 -05:00
std : : string init0 , upper_bound0 , inc0 , init1 , upper_bound1 , inc1 ;
2015-04-29 15:50:57 -04:00
std : : string data_type = append_width ( " #scalartype " , p_ . simd_width ) ;
2015-05-13 02:20:44 -04:00
driver : : backend_type backend = device . backend ( ) ;
2015-01-27 02:41:27 -05:00
2015-05-13 02:20:44 -04:00
switch ( backend )
{
2015-08-25 12:41:21 -04:00
case driver : : CUDA :
stream < < " #include \" helper_math.h \" " < < std : : endl ; break ;
case driver : : OPENCL :
stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ; break ;
2015-05-13 02:20:44 -04:00
}
2015-12-12 18:32:06 -05:00
stream < < KernelPrefix ( backend ) < < " void elementwise_1d " < < suffix < < " ( " < < _size_t < < " M, " < < _size_t < < " N, " < < generate_arguments ( " #scalartype " , device , mappings , expressions ) < < " ) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-11-19 12:37:18 -05:00
process ( stream , PARENT_NODE_TYPE , { { " array1 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " array11 " , " #scalartype #namereg = #pointer[#start]; " } ,
{ " arrayn " , " #pointer += #start; " } ,
{ " array1n " , " #pointer += #start; " } ,
{ " arrayn1 " , " #pointer += #start; " } ,
{ " arraynn " , " #pointer += #start; " } }
2015-08-06 12:05:12 -07:00
, expressions , mappings ) ;
2015-01-12 13:20:53 -05:00
2015-05-13 02:20:44 -04:00
fetching_loop_info ( p_ . fetching_policy , " M " , stream , init0 , upper_bound0 , inc0 , GlobalIdx0 ( backend ) . get ( ) , GlobalSize0 ( backend ) . get ( ) , device ) ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " i = " < < init0 < < " ; i < " < < upper_bound0 < < " ; i += " < < inc0 < < " ) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-05-13 02:20:44 -04:00
fetching_loop_info ( p_ . fetching_policy , " N " , stream , init1 , upper_bound1 , inc1 , GlobalIdx1 ( backend ) . get ( ) , GlobalSize1 ( backend ) . get ( ) , device ) ;
2015-04-29 15:50:57 -04:00
stream < < " for( " < < _size_t < < " j = " < < init1 < < " ; j < " < < upper_bound1 < < " ; j += " < < inc1 < < " ) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-11-19 12:37:18 -05:00
process ( stream , PARENT_NODE_TYPE , { { " arraynn " , data_type + " #namereg = $VALUE{i*#stride,j}; " } ,
{ " arrayn1 " , data_type + " #namereg = $VALUE{i*#stride}; " } ,
{ " arrayn " , data_type + " #namereg = $VALUE{i*#stride}; " } ,
{ " array1n " , data_type + " #namereg = $VALUE{j*#stride}; " } ,
2015-08-10 22:45:48 -07:00
{ " vdiag " , " #scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride, j*#stride)}; " } ,
2015-09-30 15:31:41 -04:00
{ " repeat " , " #scalartype #namereg = $VALUE{(i%#sub0)*#stride, (j%#sub1)}; " } ,
2015-08-06 12:05:12 -07:00
{ " outer " , " #scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride}); " } }
, expressions , mappings ) ;
2015-11-19 12:37:18 -05:00
stream < < evaluate ( PARENT_NODE_TYPE , { { " arraynn " , " #namereg " } ,
{ " array1n " , " #namereg " } ,
{ " arrayn1 " , " #namereg " } ,
{ " arrayn " , " #namereg " } ,
2015-08-06 12:05:12 -07:00
{ " vdiag " , " #namereg " } ,
{ " repeat " , " #namereg " } ,
2015-11-19 12:37:18 -05:00
{ " array1 " , " #namereg " } ,
{ " array11 " , " #namereg " } ,
2015-08-06 12:05:12 -07:00
{ " outer " , " #namereg " } ,
{ " cast " , CastPrefix ( backend , data_type ) . get ( ) } ,
{ " host_scalar " , p_ . simd_width = = 1 ? " #name " : InitPrefix ( backend , data_type ) . get ( ) + " (#name) " } }
2015-09-30 15:31:41 -04:00
, expressions , expressions . root ( ) , mappings ) < < " ; " < < std : : endl ;
2015-08-06 12:05:12 -07:00
2015-11-19 12:37:18 -05:00
process ( stream , LHS_NODE_TYPE , { { " arraynn " , " $VALUE{i*#stride,j} = #namereg; " } ,
{ " array1n " , " $VALUE{j*#stride} = #namereg; " } ,
{ " arrayn1 " , " $VALUE{i*#stride} = #namereg; " } ,
{ " arrayn " , " $VALUE{i*#stride} = #namereg; " } } , expressions , mappings ) ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-06-30 17:55:57 -04:00
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
return stream . str ( ) ;
}
2015-12-12 18:32:06 -05:00
/**
 * Creates the template from an existing parameter set.
 *
 * Both arguments are handed unchanged to the base_impl constructor.
 */
elementwise_2d::elementwise_2d(parameters_type const & parameters, binding_policy_t binding_policy)
  : base_impl<elementwise_2d, elementwise_2d_parameters>(parameters, binding_policy)
{
}
2015-01-12 13:20:53 -05:00
2015-12-12 18:32:06 -05:00
/**
 * Creates the template from individual tuning values.
 *
 * The values are packed into an elementwise_2d_parameters object (in the
 * order simd, ls1, ls2, ng1, ng2, fetch) and construction is delegated to the
 * (parameters, binding policy) constructor, which forwards both to base_impl.
 */
elementwise_2d::elementwise_2d(unsigned int simd, unsigned int ls1, unsigned int ls2,
                               unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
                               binding_policy_t bind)
  : elementwise_2d(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
{
}
2015-12-19 02:55:24 -05:00
std : : vector < int_t > elementwise_2d : : input_sizes ( expression_tree const & expression ) const
2015-01-12 13:20:53 -05:00
{
2015-09-30 15:31:41 -04:00
std : : pair < int_t , int_t > size = matrix_size ( expression . tree ( ) , lhs_most ( expression . tree ( ) , expression . root ( ) ) ) ;
2015-08-06 12:05:12 -07:00
return { size . first , size . second } ;
2015-01-12 13:20:53 -05:00
}
2015-12-12 18:32:06 -05:00
void elementwise_2d : : enqueue ( driver : : CommandQueue & /*queue*/ , driver : : Program const & program , std : : string const & suffix , base & , execution_handler const & control )
2015-01-12 13:20:53 -05:00
{
2015-12-19 02:55:24 -05:00
expression_tree const & expressions = control . x ( ) ;
2015-12-12 18:32:06 -05:00
std : : string name = " elementwise_1d " ;
2015-08-05 11:42:08 -07:00
name + = suffix ;
driver : : Kernel kernel ( program , name . c_str ( ) ) ;
2015-04-29 15:50:57 -04:00
driver : : NDRange global ( p_ . local_size_0 * p_ . num_groups_0 , p_ . local_size_1 * p_ . num_groups_1 ) ;
driver : : NDRange local ( p_ . local_size_0 , p_ . local_size_1 ) ;
2015-01-12 13:20:53 -05:00
unsigned int current_arg = 0 ;
2015-02-01 22:28:49 -05:00
std : : vector < int_t > MN = input_sizes ( expressions ) ;
2015-04-29 15:50:57 -04:00
kernel . setSizeArg ( current_arg + + , MN [ 0 ] ) ;
kernel . setSizeArg ( current_arg + + , MN [ 1 ] ) ;
2015-08-06 16:14:33 -07:00
set_arguments ( expressions , kernel , current_arg , binding_policy_ ) ;
2015-02-01 23:56:05 -05:00
2015-09-30 15:31:41 -04:00
control . execution_options ( ) . enqueue ( program . context ( ) , kernel , global , local ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-07-11 09:36:01 -04:00
}