2015-07-20 18:02:56 -07:00
# define NOMINMAX
2015-01-12 13:20:53 -05:00
# include <cassert>
2015-07-20 18:02:56 -07:00
# include <algorithm>
# include <stdexcept>
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
# include "isaac/array.h"
2015-09-30 15:31:41 -04:00
# include "isaac/tuple.h"
2015-04-29 15:50:57 -04:00
# include "isaac/exception/unknown_datatype.h"
2015-08-12 00:46:51 -07:00
# include "isaac/profiles/profiles.h"
2015-04-29 15:50:57 -04:00
# include "isaac/symbolic/execute.h"
2015-09-30 15:31:41 -04:00
# include "isaac/symbolic/io.h"
2015-01-28 17:08:39 -05:00
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
/*--- Constructors ---*/
//1D Constructors
2015-11-19 12:37:18 -05:00
// Byte size of the backing buffer: element count times element size, clamped
// to at least one byte (presumably so a zero-sized buffer is never requested).
int_t array_base::dsize()
{
  int_t const nbytes = shape_.prod() * size_of(dtype_);
  return std::max(static_cast<int_t>(1), nbytes);
}
// 1-D array of `shape0` elements of type `dtype` on `context`, backed by a new
// buffer of dsize() bytes, with zero offset and unit stride.
// T is initialized with the transposed expression of this array.
array_base::array_base(int_t shape0, numeric_type dtype, driver::Context const & context)
  : dtype_(dtype),
    shape_{shape0},
    start_(0),
    stride_(1),
    context_(context),
    data_(context_, dsize()),
    T(isaac::trans(*this))
{
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
// 1-D array over an existing buffer: `shape0` elements beginning at element
// offset `start`, `inc` elements apart. The buffer handle `data` is reused
// (no dsize()-sized buffer is created); the context comes from the buffer.
array_base::array_base(int_t shape0, numeric_type dtype, driver::Buffer data, int_t start, int_t inc)
  : dtype_(dtype),
    shape_{shape0},
    start_(start),
    stride_(inc),
    context_(data.context()),
    data_(data),
    T(isaac::trans(*this))
{
}
2015-01-21 20:08:52 -05:00
template < class DT >
2015-11-19 12:37:18 -05:00
array_base : : array_base ( std : : vector < DT > const & x , driver : : Context const & context ) :
dtype_ ( to_numeric_type < DT > : : value ) , shape_ { ( int_t ) x . size ( ) } , start_ ( 0 ) , stride_ ( 1 ) ,
context_ ( context ) , data_ ( context , dsize ( ) ) ,
2015-10-01 17:23:26 -04:00
T ( isaac : : trans ( * this ) )
2015-01-12 13:20:53 -05:00
{ * this = x ; }
2015-11-19 12:37:18 -05:00
// 1-D slice of `v`: shares v's buffer, moves the start by s0.start strides of v
// and multiplies v's stride by s0.stride.
array_base::array_base(array_base & v, slice const & s0)
  : dtype_(v.dtype_),
    shape_{s0.size(v.shape_[0])},
    start_(v.start_ + v.stride_[0] * s0.start),
    stride_(v.stride_[0] * s0.stride),
    context_(v.context()),
    data_(v.data_),
    T(isaac::trans(*this))
{
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
// Explicit instantiations of the host-vector 1-D constructor, one per
// supported host element type (signed/unsigned pairs share a line).
#define INSTANTIATE(T) template ISAACAPI array_base::array_base(std::vector<T> const &, driver::Context const &)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
// 2D
2015-11-19 12:37:18 -05:00
// 2-D array of shape0 x shape1 elements backed by a new buffer of dsize()
// bytes. Strides {1, shape0} make the first index contiguous (element (i, j)
// sits at i*1 + j*shape0).
array_base::array_base(int_t shape0, int_t shape1, numeric_type dtype, driver::Context const & context)
  : dtype_(dtype),
    shape_{shape0, shape1},
    start_(0),
    stride_(1, shape0),
    context_(context),
    data_(context_, dsize()),
    T(isaac::trans(*this))
{
}
2015-11-19 12:37:18 -05:00
// 2-D array over an existing buffer, beginning at element `start`, with unit
// stride along the first index and leading dimension `ld` along the second.
array_base::array_base(int_t shape0, int_t shape1, numeric_type dtype, driver::Buffer data, int_t start, int_t ld)
  : dtype_(dtype),
    shape_{shape0, shape1},
    start_(start),
    stride_(1, ld),
    context_(data.context()),
    data_(data),
    T(isaac::trans(*this))
{
}
2015-11-19 12:37:18 -05:00
// 2-D sub-block of M selected by row slice s0 and column slice s1. The result
// shares M's buffer; its start and strides are derived from M's.
array_base::array_base(array_base & M, slice const & s0, slice const & s1)
  : dtype_(M.dtype_),
    shape_{s0.size(M.shape_[0]), s1.size(M.shape_[1])},
    start_(M.start_ + M.stride_[0] * s0.start + s1.start * M.stride_[1]),
    stride_(M.stride_[0] * s0.stride, M.stride_[1] * s1.stride),
    context_(M.data_.context()),
    data_(M.data_),
    T(isaac::trans(*this))
{
}
2015-06-28 17:53:16 -07:00
2015-01-21 20:08:52 -05:00
template < typename DT >
2015-11-19 12:37:18 -05:00
array_base : : array_base ( int_t shape0 , int_t shape1 , std : : vector < DT > const & data , driver : : Context const & context )
2015-02-04 22:06:15 -05:00
: dtype_ ( to_numeric_type < DT > : : value ) ,
2015-11-19 12:37:18 -05:00
shape_ { shape0 , shape1 } , start_ ( 0 ) , stride_ ( 1 , shape0 ) ,
context_ ( context ) , data_ ( context_ , dsize ( ) ) ,
2015-10-01 17:23:26 -04:00
T ( isaac : : trans ( * this ) )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
isaac : : copy ( data , * this ) ;
2015-01-12 13:20:53 -05:00
}
2015-06-25 08:12:16 -07:00
// 3D
2015-11-19 12:37:18 -05:00
// 3-D array of shape0 x shape1 x shape2 elements backed by a new buffer of
// dsize() bytes.
// NOTE(review): this appears to record only two strides, {1, shape0}; a third
// stride (shape0*shape1 for a dense layout) is not set. Confirm whether
// anything reads stride_[2] for 3-D arrays.
array_base::array_base(int_t shape0, int_t shape1, int_t shape2, numeric_type dtype, driver::Context const & context)
  : dtype_(dtype),
    shape_{shape0, shape1, shape2},
    start_(0),
    stride_(1, shape0),
    context_(context),
    data_(context_, dsize()),
    T(isaac::trans(*this))
{
}
2015-11-19 12:37:18 -05:00
// Explicit instantiations of the host-vector 2-D constructor, one per
// supported host element type (signed/unsigned pairs share a line).
#define INSTANTIATE(T) template ISAACAPI array_base::array_base(int_t, int_t, std::vector<T> const &, driver::Context const &)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
2015-11-19 12:37:18 -05:00
// Fully general constructor: caller supplies shape, element offset and strides.
// A new buffer of dsize() bytes is created on `context`.
array_base::array_base(numeric_type dtype, shape_t const & shape, int_t start, shape_t const & stride, driver::Context const & context)
  : dtype_(dtype),
    shape_(shape),
    start_(start),
    stride_(stride),
    context_(context),
    data_(context_, dsize()),
    T(isaac::trans(*this))
{
}
2015-09-30 15:31:41 -04:00
2015-11-19 12:37:18 -05:00
// Array of the given shape with zero offset and default strides {1, shape[0]};
// forwards to the (dtype, shape, start, stride, context) constructor.
array_base::array_base(numeric_type dtype, shape_t const & shape, driver::Context const & context)
  : array_base(dtype, shape, 0, {1, shape[0]}, context)
{
}
// Materializes an execution_handler. The array takes the dtype, shape and
// context of the handler's expression, gets a new buffer, and then evaluates
// the expression into it via operator=(execution_handler const &).
// stride_ reads shape_[0], which relies on shape_ being declared (and therefore
// initialized) before stride_ in the class.
array_base::array_base(execution_handler const & other)
  : dtype_(other.x().dtype()),
    shape_(other.x().shape()),
    start_(0),
    stride_(1, shape_[0]),
    context_(other.x().context()),
    data_(context_, dsize()),
    T(isaac::trans(*this))
{
  *this = other;
}
2015-11-19 12:37:18 -05:00
// Destructor: no explicit cleanup here; member destructors run as usual.
array_base::~array_base() {}
2015-01-12 13:20:53 -05:00
/*--- Getters ---*/
2015-11-19 12:37:18 -05:00
/// Element type of the array.
numeric_type array_base::dtype() const { return dtype_; }
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
/// Extents of the array, one entry per dimension.
shape_t const & array_base::shape() const { return shape_; }
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
/// Number of dimensions, i.e. the length of the shape tuple.
int_t array_base::dim() const { return static_cast<int_t>(shape_.size()); }
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
/// Offset, in elements, of the first element within the backing buffer.
int_t array_base::start() const { return start_; }
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
/// Per-dimension element strides into the backing buffer.
shape_t const & array_base::stride() const { return stride_; }
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
/// Driver context this array lives in.
driver::Context const & array_base::context() const { return context_; }
2015-11-19 12:37:18 -05:00
/// Read-only access to the backing buffer handle.
driver::Buffer const & array_base::data() const { return data_; }
2015-11-19 12:37:18 -05:00
/// Mutable access to the backing buffer handle.
driver::Buffer & array_base::data() { return data_; }
2015-01-12 13:20:53 -05:00
/*--- Assignment Operators ----*/
//---------------------------------------
2015-09-30 15:31:41 -04:00
2015-11-19 12:37:18 -05:00
// Copies rhs element-wise into this array by building and executing an ASSIGN
// expression. Returns immediately when shape_.min() is 0 (some extent is zero).
array_base & array_base::operator=(array_base const & rhs)
{
  if (shape_.min() == 0)
    return *this;

  assert(dtype_ == rhs.dtype());
  math_expression assignment(*this, rhs,
                             op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE),
                             context_, dtype_, shape_);
  execute(execution_handler(assignment));
  return *this;
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
// Fills every element with the value `rhs` by executing an ASSIGN expression.
// Returns immediately when shape_.min() is 0 (some extent is zero).
array_base & array_base::operator=(value_scalar const & rhs)
{
  if (shape_.min() == 0)
    return *this;

  assert(dtype_ == rhs.dtype());
  math_expression assignment(*this, rhs,
                             op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE),
                             context_, dtype_, shape_);
  execute(execution_handler(assignment));
  return *this;
}
2015-02-05 04:42:57 -05:00
2015-09-30 15:31:41 -04:00
2015-11-19 12:37:18 -05:00
// Evaluates the handler's expression into this array. The execution,
// dispatcher and compilation options carried by `c` are forwarded unchanged.
// Returns immediately when shape_.min() is 0 (some extent is zero).
array_base & array_base::operator=(execution_handler const & c)
{
  if (shape_.min() == 0)
    return *this;

  assert(dtype_ == c.x().dtype());
  math_expression assignment(*this, c.x(),
                             op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE),
                             context_, dtype_, shape_);
  execute(execution_handler(assignment, c.execution_options(), c.dispatcher_options(), c.compilation_options()));
  return *this;
}
2015-11-19 12:37:18 -05:00
// Evaluates `rhs` with default execution options by wrapping it in an
// execution_handler.
array_base & array_base::operator=(math_expression const & rhs)
{
  execution_handler const handler(rhs);
  return *this = handler;
}
2015-07-28 15:13:43 -07:00
2015-01-21 20:08:52 -05:00
template < class DT >
2015-11-19 12:37:18 -05:00
array_base & array_base : : operator = ( std : : vector < DT > const & rhs )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
assert ( dim ( ) < = 1 ) ;
2015-04-29 15:50:57 -04:00
isaac : : copy ( rhs , * this ) ;
2015-01-12 13:20:53 -05:00
return * this ;
}
2015-11-19 12:37:18 -05:00
// Explicit instantiations of operator=(std::vector<TYPE> const &), one per
// supported host element type (signed/unsigned pairs share a line).
#define INSTANTIATE(TYPE) template ISAACAPI array_base& array_base::operator=<TYPE>(std::vector<TYPE> const &)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
2015-01-19 21:29:47 -05:00
2015-06-27 11:44:50 -04:00
2015-01-12 13:20:53 -05:00
2015-01-19 21:29:47 -05:00
2015-11-19 12:37:18 -05:00
// Lazy negation: a unary SUB expression with this array's dtype and shape.
math_expression array_base::operator-()
{
  return math_expression(*this, invalid_node(),
                         op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE),
                         context_, dtype_, shape_);
}
2015-01-19 21:29:47 -05:00
2015-11-19 12:37:18 -05:00
// Lazy logical negation: a unary NEGATE expression whose result dtype is
// INT_TYPE (the same result type the comparison operators use).
math_expression array_base::operator!()
{
  return math_expression(*this, invalid_node(),
                         op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE),
                         context_, INT_TYPE, shape_);
}
2015-01-29 15:19:40 -05:00
2015-01-12 13:20:53 -05:00
//
2015-11-19 12:37:18 -05:00
// In-place addition: *this is rewritten as the expression (*this + rhs) and
// evaluated through operator=(math_expression const &).
array_base & array_base::operator+=(value_scalar const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE),
                                 context_, dtype_, shape_);
}

array_base & array_base::operator+=(array_base const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE),
                                 context_, dtype_, shape_);
}

// The expression overload takes the context from rhs rather than from *this.
array_base & array_base::operator+=(math_expression const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE),
                                 rhs.context(), dtype_, shape_);
}
2015-01-19 21:29:47 -05:00
//----
2015-11-19 12:37:18 -05:00
// In-place subtraction: *this is rewritten as the expression (*this - rhs) and
// evaluated through operator=(math_expression const &).
array_base & array_base::operator-=(value_scalar const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE),
                                 context_, dtype_, shape_);
}

array_base & array_base::operator-=(array_base const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE),
                                 context_, dtype_, shape_);
}

// The expression overload takes the context from rhs rather than from *this.
array_base & array_base::operator-=(math_expression const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE),
                                 rhs.context(), dtype_, shape_);
}
2015-01-19 21:29:47 -05:00
//----
2015-11-19 12:37:18 -05:00
// In-place element-wise multiplication: *this is rewritten as the expression
// (*this * rhs) and evaluated through operator=(math_expression const &).
array_base & array_base::operator*=(value_scalar const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE),
                                 context_, dtype_, shape_);
}

array_base & array_base::operator*=(array_base const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE),
                                 context_, dtype_, shape_);
}

// The expression overload takes the context from rhs rather than from *this.
array_base & array_base::operator*=(math_expression const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE),
                                 rhs.context(), dtype_, shape_);
}
2015-01-19 21:29:47 -05:00
//----
2015-11-19 12:37:18 -05:00
// In-place element-wise division: *this is rewritten as the expression
// (*this / rhs) and evaluated through operator=(math_expression const &).
array_base & array_base::operator/=(value_scalar const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE),
                                 context_, dtype_, shape_);
}

array_base & array_base::operator/=(array_base const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE),
                                 context_, dtype_, shape_);
}

// The expression overload takes the context from rhs rather than from *this.
array_base & array_base::operator/=(math_expression const & rhs)
{
  return *this = math_expression(*this, rhs,
                                 op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE),
                                 rhs.context(), dtype_, shape_);
}
2015-01-12 13:20:53 -05:00
/*--- Indexing operators -----*/
//---------------------------------------
2015-11-19 12:37:18 -05:00
// Symbolic indexing with a loop index: builds an ACCESS_INDEX expression node
// of shape {1} over this array.
math_expression array_base::operator[](for_idx_t idx) const
{
  return math_expression(*this, idx,
                         op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ACCESS_INDEX_TYPE),
                         context_, dtype_, {1});
}
2015-11-19 12:37:18 -05:00
/// Returns a scalar referring to logical element `idx` of this (at most 1-D)
/// array, in the same buffer.
///
/// The buffer offset honours stride_[0]. For strided views (produced by the
/// slice constructor or by a Buffer constructor with inc != 1), `idx` therefore
/// selects the idx-th element of the view rather than the idx-th physical
/// element after start_.
scalar array_base::operator[](int_t idx)
{
  assert(dim() <= 1);
  // Guard an empty stride tuple (0-d arrays): only element 0 exists there.
  int_t const step = stride_.size() > 0 ? stride_[0] : 1;
  return scalar(dtype_, data_, start_ + idx * step);
}
2015-11-19 12:37:18 -05:00
/// Const overload of element access: returns a scalar referring to logical
/// element `idx` of this (at most 1-D) array, in the same buffer.
///
/// The buffer offset honours stride_[0], so strided views are indexed by
/// logical position rather than by raw element offset from start_.
const scalar array_base::operator[](int_t idx) const
{
  assert(dim() <= 1);
  // Guard an empty stride tuple (0-d arrays): only element 0 exists there.
  int_t const step = stride_.size() > 0 ? stride_[0] : 1;
  return scalar(dtype_, data_, start_ + idx * step);
}
2015-11-19 12:37:18 -05:00
// 1-D slicing: returns a view over this array selected by `e1`
// (built by the view(array_base &, slice const &) constructor).
view array_base::operator[](slice const & e1)
{
  assert(dim() <= 1);
  return view(*this, e1);
}
2015-11-19 12:37:18 -05:00
// Element (i, j) of a 2-D array, as a length-1 view into the same buffer.
view array_base::operator()(int_t i, int_t j)
{
  assert(dim() == 2 && "Too many indices in array");
  int_t const offset = start_ + i * stride_[0] + j * stride_[1];
  return view(1, dtype_, data_, offset, 1);
}
// Row i of a 2-D array restricted to the columns selected by `sj`, returned
// as a 1-D view into the same buffer.
view array_base::operator()(int_t i, slice const & sj)
{
  assert(dim() == 2 && "Too many indices in array");
  int_t const first = start_ + i * stride_[0] + sj.start * stride_[1];
  int_t const step = sj.stride * stride_[1];
  return view(sj.size(shape_[1]), dtype_, data_, first, step);
}
// Column j of a 2-D array restricted to the rows selected by `si`, returned
// as a 1-D view into the same buffer.
view array_base::operator()(slice const & si, int_t j)
{
  assert(dim() == 2 && "Too many indices in array");
  int_t const first = start_ + si.start * stride_[0] + j * stride_[1];
  // Consecutive selected rows are si.stride rows apart, and one row step is
  // stride_[0] buffer elements. Omitting stride_[0] produced wrong elements
  // for any array whose row stride is not 1 (e.g. a strided 2-D slice).
  // This mirrors the (int_t, slice) overload and the 2-D slice constructor.
  int_t const step = si.stride * stride_[0];
  return view(si.size(shape_[0]), dtype_, data_, first, step);
}
// Rectangular sub-block of a 2-D array selected by row slice `si` and column
// slice `sj` (built by the 2-D slice constructor).
view array_base::operator()(slice const & si, slice const & sj)
{
  assert(dim() == 2 && "Too many indices in array");
  return view(*this, si, sj);
}
//---------------------------------------
/*--- array ---*/
// Evaluates `proxy` into a newly constructed array.
array::array(math_expression const & proxy)
  : array_base(execution_handler(proxy))
{
}

// Copy from any array_base. The new array is allocated with other's
// dtype/shape/context (default strides), then filled by assignment.
array::array(array_base const & other)
  : array_base(other.dtype(), other.shape(), other.context())
{
  *this = other;
}

// Copy constructor: forwards to the array_base overload above.
array::array(array const & other)
  : array(static_cast<array_base const &>(other))
{
}
2015-10-06 16:34:47 -04:00
2015-10-03 18:48:20 -04:00
//---------------------------------------
/*--- View ---*/
2015-11-19 12:37:18 -05:00
// view constructors: each passes its arguments straight through to the
// corresponding array_base constructor.
view::view(array & data)
  : array_base(data) {}

view::view(array_base & data, slice const & s1)
  : array_base(data, s1) {}

view::view(array_base & data, slice const & s1, slice const & s2)
  : array_base(data, s1, s2) {}

view::view(int_t size1, numeric_type dtype, driver::Buffer data, int_t start, int_t inc)
  : array_base(size1, dtype, data, start, inc) {}
2015-10-03 18:48:20 -04:00
2015-01-12 13:20:53 -05:00
2015-05-01 21:39:29 -04:00
//---------------------------------------
2015-01-12 13:20:53 -05:00
/*--- Scalar ---*/
namespace detail
{
template < class T >
2015-07-31 00:41:03 -07:00
void copy ( driver : : Context const & context , driver : : Buffer const & data , T value )
2015-01-12 13:20:53 -05:00
{
2015-08-03 16:05:57 -07:00
driver : : backend : : queues : : get ( context , 0 ) . write ( data , CL_TRUE , 0 , sizeof ( T ) , ( void * ) & value ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-11-19 12:37:18 -05:00
// Scalar referring to element `offset` of an existing buffer
// (a length-1, unit-stride array over `data`).
scalar::scalar(numeric_type dtype, const driver::Buffer & data, int_t offset)
  : array_base(1, dtype, data, offset, 1)
{
}
2015-11-19 12:37:18 -05:00
// Scalar on `context` holding `value`. A one-element array of value's dtype is
// allocated, and the value, converted to the matching host type, is written
// into it. Throws unknown_datatype for unsupported dtypes.
scalar::scalar(value_scalar value, driver::Context const & context)
  : array_base(1, value.dtype(), context)
{
#define ISAAC_UPLOAD(DTYPE, CTYPE) \
  case DTYPE: detail::copy(context_, data_, static_cast<CTYPE>(value)); break
  switch (dtype_)
  {
    ISAAC_UPLOAD(CHAR_TYPE, char);
    ISAAC_UPLOAD(UCHAR_TYPE, unsigned char);
    ISAAC_UPLOAD(SHORT_TYPE, short);
    ISAAC_UPLOAD(USHORT_TYPE, unsigned short);
    ISAAC_UPLOAD(INT_TYPE, int);
    ISAAC_UPLOAD(UINT_TYPE, unsigned int);
    ISAAC_UPLOAD(LONG_TYPE, long);
    ISAAC_UPLOAD(ULONG_TYPE, unsigned long);
    ISAAC_UPLOAD(FLOAT_TYPE, float);
    ISAAC_UPLOAD(DOUBLE_TYPE, double);
    default: throw unknown_datatype(dtype_);
  }
#undef ISAAC_UPLOAD
}
2015-11-19 12:37:18 -05:00
// Scalar of the given dtype with its own one-element buffer; the buffer's
// contents are not written here.
scalar::scalar(numeric_type dtype, driver::Context const & context)
  : array_base(1, dtype, context)
{
}

// Scalar holding the evaluated result of `proxy`.
scalar::scalar(math_expression const & proxy)
  : array_base(proxy)
{
}
2015-01-12 13:20:53 -05:00
2015-05-04 21:23:05 -04:00
void scalar : : inject ( values_holder & v ) const
{
int_t dtsize = size_of ( dtype_ ) ;
# define HANDLE_CASE(DTYPE, VAL) \
case DTYPE : \
2015-11-19 12:37:18 -05:00
driver : : backend : : queues : : get ( context_ , 0 ) . read ( data_ , CL_TRUE , start_ * dtsize , dtsize , ( void * ) & v . VAL ) ; break ; \
2015-05-04 21:23:05 -04:00
switch ( dtype_ )
{
HANDLE_CASE ( CHAR_TYPE , int8 ) ;
HANDLE_CASE ( UCHAR_TYPE , uint8 ) ;
HANDLE_CASE ( SHORT_TYPE , int16 ) ;
HANDLE_CASE ( USHORT_TYPE , uint16 ) ;
HANDLE_CASE ( INT_TYPE , int32 ) ;
HANDLE_CASE ( UINT_TYPE , uint32 ) ;
HANDLE_CASE ( LONG_TYPE , int64 ) ;
HANDLE_CASE ( ULONG_TYPE , uint64 ) ;
HANDLE_CASE ( FLOAT_TYPE , float32 ) ;
HANDLE_CASE ( DOUBLE_TYPE , float64 ) ;
default : throw unknown_datatype ( dtype_ ) ;
}
# undef HANDLE_CASE
}
2015-07-21 17:18:50 -04:00
template < class TYPE >
TYPE scalar : : cast ( ) const
2015-01-12 13:20:53 -05:00
{
2015-01-13 01:17:27 -05:00
values_holder v ;
2015-05-04 21:23:05 -04:00
inject ( v ) ;
2015-07-21 17:18:50 -04:00
# define HANDLE_CASE(DTYPE, VAL) case DTYPE: return static_cast<TYPE>(v.VAL)
2015-01-13 01:17:27 -05:00
2015-02-04 22:06:15 -05:00
switch ( dtype_ )
2015-01-13 01:17:27 -05:00
{
HANDLE_CASE ( CHAR_TYPE , int8 ) ;
HANDLE_CASE ( UCHAR_TYPE , uint8 ) ;
HANDLE_CASE ( SHORT_TYPE , int16 ) ;
HANDLE_CASE ( USHORT_TYPE , uint16 ) ;
HANDLE_CASE ( INT_TYPE , int32 ) ;
HANDLE_CASE ( UINT_TYPE , uint32 ) ;
HANDLE_CASE ( LONG_TYPE , int64 ) ;
HANDLE_CASE ( ULONG_TYPE , uint64 ) ;
HANDLE_CASE ( FLOAT_TYPE , float32 ) ;
HANDLE_CASE ( DOUBLE_TYPE , float64 ) ;
2015-02-04 22:06:15 -05:00
default : throw unknown_datatype ( dtype_ ) ;
2015-01-13 01:17:27 -05:00
}
# undef HANDLE_CASE
}
// Stores `s`, converted to the host type matching dtype_, into this scalar's
// element. The write is issued on queue 0 with CL_TRUE as the blocking flag.
// Throws unknown_datatype for unsupported dtypes.
scalar & scalar::operator=(value_scalar const & s)
{
  driver::CommandQueue & queue = driver::backend::queues::get(context_, 0);
  int_t const nbytes = size_of(dtype_);
  int_t const offset = start_ * nbytes;
#define ISAAC_STORE(DTYPE, CTYPE) \
  case DTYPE: \
  { \
    CTYPE host = s; \
    queue.write(data_, CL_TRUE, offset, nbytes, (void*)&host); \
    return *this; \
  }
  switch (dtype_)
  {
    ISAAC_STORE(CHAR_TYPE, char)
    ISAAC_STORE(UCHAR_TYPE, unsigned char)
    ISAAC_STORE(SHORT_TYPE, short)
    ISAAC_STORE(USHORT_TYPE, unsigned short)
    ISAAC_STORE(INT_TYPE, int)
    ISAAC_STORE(UINT_TYPE, unsigned int)
    ISAAC_STORE(LONG_TYPE, long)
    ISAAC_STORE(ULONG_TYPE, unsigned long)
    ISAAC_STORE(FLOAT_TYPE, float)
    ISAAC_STORE(DOUBLE_TYPE, double)
    default: throw unknown_datatype(dtype_);
  }
#undef ISAAC_STORE
}
2015-01-13 01:17:27 -05:00
// Conversion operators from a device scalar to each host arithmetic type.
// Each one reads the value back and converts it with cast<type>().
#define INSTANTIATE(type) scalar::operator type() const { return cast<type>(); }
INSTANTIATE(char)      INSTANTIATE(unsigned char)
INSTANTIATE(short)     INSTANTIATE(unsigned short)
INSTANTIATE(int)       INSTANTIATE(unsigned int)
INSTANTIATE(long)      INSTANTIATE(unsigned long)
INSTANTIATE(long long) INSTANTIATE(unsigned long long)
INSTANTIATE(float)     INSTANTIATE(double)
#undef INSTANTIATE
2015-01-12 13:20:53 -05:00
2015-01-16 07:31:39 -05:00
std : : ostream & operator < < ( std : : ostream & os , scalar const & s )
2015-01-12 13:20:53 -05:00
{
switch ( s . dtype ( ) )
{
2015-07-21 17:18:50 -04:00
// case BOOL_TYPE: return os << static_cast<bool>(s);
2015-05-04 21:23:05 -04:00
case CHAR_TYPE : return os < < static_cast < char > ( s ) ;
case UCHAR_TYPE : return os < < static_cast < unsigned char > ( s ) ;
case SHORT_TYPE : return os < < static_cast < short > ( s ) ;
case USHORT_TYPE : return os < < static_cast < unsigned short > ( s ) ;
case INT_TYPE : return os < < static_cast < int > ( s ) ;
case UINT_TYPE : return os < < static_cast < unsigned int > ( s ) ;
case LONG_TYPE : return os < < static_cast < long > ( s ) ;
case ULONG_TYPE : return os < < static_cast < unsigned long > ( s ) ;
2015-07-21 17:18:50 -04:00
// case HALF_TYPE: return os << static_cast<half>(s);
2015-05-04 21:23:05 -04:00
case FLOAT_TYPE : return os < < static_cast < float > ( s ) ;
case DOUBLE_TYPE : return os < < static_cast < double > ( s ) ;
2015-01-19 21:29:47 -05:00
default : throw unknown_datatype ( s . dtype ( ) ) ;
2015-01-12 13:20:53 -05:00
}
}
/*--- Binary Operators ----*/
//-----------------------------------
2015-11-19 12:37:18 -05:00
// Broadcast shape of `a` and `b`. The shorter shape is left-padded with unit
// extents, then each dimension takes the larger extent. Per dimension the
// extents must match or one must be 1 (checked only by assert).
shape_t broadcast(shape_t const & a, shape_t const & b)
{
  std::vector<int_t> lhs = a, rhs = b;
  std::size_t const rank = std::max(lhs.size(), rhs.size());
  lhs.insert(lhs.begin(), rank - lhs.size(), 1);
  rhs.insert(rhs.begin(), rank - rhs.size(), 1);

  std::vector<int_t> out;
  out.reserve(rank);
  for (std::size_t d = 0; d < rank; ++d)
  {
    assert((lhs[d] == rhs[d] || lhs[d] == 1 || rhs[d] == 1) && "Cannot broadcast");
    out.push_back(std::max(lhs[d], rhs[d]));
  }
  return shape_t(out);
}
2015-01-29 01:00:50 -05:00
// DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) emits every overload of the
// element-wise binary function OPNAME (an operator or a named function) over
// array_base, math_expression, value_scalar and for_idx_t operands. Each
// overload builds a math_expression node for operator OP.
//  - DTYPE is pasted into each body as the result dtype. The invocations below
//    pass `x.dtype()` or INT_TYPE, so parameter names must stay `x`/`y` exactly
//    as written (in several overloads `x` is the right-hand operand).
//  - Two array/expression operands: shape is broadcast(x.shape(), y.shape()).
//  - One array/expression operand: shape is that operand's shape.
//  - Only value_scalar / for_idx_t operands: no context or shape is passed.
#define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
/* array_base on the left */ \
math_expression OPNAME(array_base const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
math_expression OPNAME(array_base const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
math_expression OPNAME(array_base const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(array_base const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
/* math_expression on the left */ \
math_expression OPNAME(math_expression const & x, math_expression const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
math_expression OPNAME(math_expression const & x, array_base const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, broadcast(x.shape(), y.shape())); } \
math_expression OPNAME(math_expression const & x, value_scalar const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(math_expression const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
/* value_scalar on the left */ \
math_expression OPNAME(value_scalar const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(value_scalar const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(value_scalar const & x, for_idx_t const & y) \
{ return math_expression(x, y, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
/* for_idx_t on the left */ \
math_expression OPNAME(for_idx_t const & y, math_expression const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(for_idx_t const & y, value_scalar const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), DTYPE); } \
math_expression OPNAME(for_idx_t const & y, array_base const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP), x.context(), DTYPE, x.shape()); } \
math_expression OPNAME(for_idx_t const & y, for_idx_t const & x) \
{ return math_expression(y, x, op_element(OPERATOR_BINARY_TYPE_FAMILY, OP)); }

// Arithmetic and assignment: the result keeps the dtype of `x`.
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ADD_TYPE, operator+, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_SUB_TYPE, operator-, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_MULT_TYPE, operator*, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_DIV_TYPE, operator/, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow, x.dtype())
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ASSIGN_TYPE, assign, x.dtype())

// Comparisons: the result dtype is INT_TYPE.
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GREATER_TYPE, operator>, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_GEQ_TYPE, operator>=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LESS_TYPE, operator<, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator<=, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator==, INT_TYPE)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator!=, INT_TYPE)
2015-01-12 13:20:53 -05:00
2015-06-30 17:55:57 -04:00
// DEFINE_OUTER(LTYPE, RTYPE) emits outer(x, y) for operands of at most one
// dimension. If either operand is 0-d, it reduces to x * y. Otherwise it
// builds an OUTER_PROD node of shape {x.shape().max(), y.shape().max()}.
#define DEFINE_OUTER(LTYPE, RTYPE) \
math_expression outer(LTYPE const & x, RTYPE const & y) \
{ \
  assert(x.dim() <= 1 && y.dim() <= 1); \
  if (x.dim() < 1 || y.dim() < 1) \
    return x * y; \
  return math_expression(x, y, \
                         op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_OUTER_PROD_TYPE), \
                         x.context(), x.dtype(), \
                         {x.shape().max(), y.shape().max()}); \
}

DEFINE_OUTER(array_base, array_base)
DEFINE_OUTER(math_expression, array_base)
DEFINE_OUTER(array_base, math_expression)
DEFINE_OUTER(math_expression, math_expression)

#undef DEFINE_ELEMENT_BINARY_OPERATOR
2015-09-30 15:31:41 -04:00
// DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE) emits rot(x, y, c, s), which fuses
// two assignments into one expression:
//   x <- c*x + s*y
//   y <- c*y - s*x
// NOTE(review): whether the second assignment sees the original or the
// updated x depends on how fuse() schedules its operands; confirm.
#define DEFINE_ROT(LTYPE, RTYPE, CTYPE, STYPE) \
math_expression rot(LTYPE const & x, RTYPE const & y, CTYPE const & c, STYPE const & s) \
{ \
  return fuse(assign(x, c*x + s*y), \
              assign(y, c*y - s*x)); \
}

// Device-scalar coefficients.
DEFINE_ROT(array_base, array_base, scalar, scalar)
DEFINE_ROT(math_expression, array_base, scalar, scalar)
DEFINE_ROT(array_base, math_expression, scalar, scalar)
DEFINE_ROT(math_expression, math_expression, scalar, scalar)

// Host-value coefficients.
DEFINE_ROT(array_base, array_base, value_scalar, value_scalar)
DEFINE_ROT(math_expression, array_base, value_scalar, value_scalar)
DEFINE_ROT(array_base, math_expression, value_scalar, value_scalar)
DEFINE_ROT(math_expression, math_expression, value_scalar, value_scalar)

// Expression coefficients.
DEFINE_ROT(array_base, array_base, math_expression, math_expression)
DEFINE_ROT(math_expression, array_base, math_expression, math_expression)
DEFINE_ROT(array_base, math_expression, math_expression, math_expression)
DEFINE_ROT(math_expression, math_expression, math_expression, math_expression)
2015-01-12 13:20:53 -05:00
//---------------------------------------
/*--- Math Operators----*/
//---------------------------------------
// Generates OPNAME(x) for array_base and math_expression operands. Each overload wraps x in a
// unary element-wise node whose context, dtype and shape are inherited from x.
// OP is expanded inside the function body, so it may refer to the operand `x` (see abs below).
#define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
math_expression OPNAME(array_base const & x) \
{ \
  op_element const unary_op(OPERATOR_UNARY_TYPE_FAMILY, OP); \
  return math_expression(x, invalid_node(), unary_op, x.context(), x.dtype(), x.shape()); \
} \
\
math_expression OPNAME(math_expression const & x) \
{ \
  op_element const unary_op(OPERATOR_UNARY_TYPE_FAMILY, OP); \
  return math_expression(x, invalid_node(), unary_op, x.context(), x.dtype(), x.shape()); \
}

// abs picks the floating-point variant for FLOAT/DOUBLE operands and the integer variant otherwise.
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE) ? OPERATOR_FABS_TYPE : OPERATOR_ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_ATAN_TYPE, atan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_CEIL_TYPE, ceil)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COS_TYPE, cos)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_COSH_TYPE, cosh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_EXP_TYPE, exp)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_FLOOR_TYPE, floor)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG_TYPE, log)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_LOG10_TYPE, log10)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SIN_TYPE, sin)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SINH_TYPE, sinh)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_SQRT_TYPE, sqrt)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TAN_TYPE, tan)
DEFINE_ELEMENT_UNARY_OPERATOR(OPERATOR_TANH_TYPE, tanh)
#undef DEFINE_ELEMENT_UNARY_OPERATOR
//---------------------------------------
2015-01-16 19:39:26 -05:00
2015-01-12 13:20:53 -05:00
///*--- Misc----*/
////---------------------------------------
2015-01-29 01:00:50 -05:00
// Maps a target numeric type to the matching element-wise cast operator.
// Throws unknown_datatype for types without a cast operator (BOOL and HALF are currently disabled).
inline operation_node_type casted(numeric_type dtype)
{
  operation_node_type cast_op;
  switch(dtype)
  {
    // case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE;
    case CHAR_TYPE:   cast_op = OPERATOR_CAST_CHAR_TYPE;   break;
    case UCHAR_TYPE:  cast_op = OPERATOR_CAST_UCHAR_TYPE;  break;
    case SHORT_TYPE:  cast_op = OPERATOR_CAST_SHORT_TYPE;  break;
    case USHORT_TYPE: cast_op = OPERATOR_CAST_USHORT_TYPE; break;
    case INT_TYPE:    cast_op = OPERATOR_CAST_INT_TYPE;    break;
    case UINT_TYPE:   cast_op = OPERATOR_CAST_UINT_TYPE;   break;
    case LONG_TYPE:   cast_op = OPERATOR_CAST_LONG_TYPE;   break;
    case ULONG_TYPE:  cast_op = OPERATOR_CAST_ULONG_TYPE;  break;
    // case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE;
    case FLOAT_TYPE:  cast_op = OPERATOR_CAST_FLOAT_TYPE;  break;
    case DOUBLE_TYPE: cast_op = OPERATOR_CAST_DOUBLE_TYPE; break;
    default: throw unknown_datatype(dtype);
  }
  return cast_op;
}
2015-11-19 12:37:18 -05:00
// Element-wise conversion of x to dtype. The result keeps x's context and shape; only the dtype changes.
math_expression cast(array_base const & x, numeric_type dtype)
{
  op_element const conversion(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype));
  return math_expression(x, invalid_node(), conversion, x.context(), dtype, x.shape());
}

math_expression cast(math_expression const & x, numeric_type dtype)
{
  op_element const conversion(OPERATOR_UNARY_TYPE_FAMILY, casted(dtype));
  return math_expression(x, invalid_node(), conversion, x.context(), dtype, x.shape());
}
2015-01-29 01:00:50 -05:00
2015-09-30 15:31:41 -04:00
// M x N expression built from an OPERATOR_VDIAG_TYPE node over the scalars 1 and 0.
// NOTE(review): presumably 1 fills the diagonal and 0 the rest -- confirm against the VDIAG code generator.
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{
  shape_t const dims{M, N};
  op_element const vdiag(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE);
  return math_expression(value_scalar(1), value_scalar(0), vdiag, ctx, dtype, dims);
}
2015-01-29 01:00:50 -05:00
2015-11-19 12:37:18 -05:00
// Returns a 1-D view of the `offset`-th diagonal of the 2-D array x (offset > 0 selects a diagonal
// to the right of the main one, offset < 0 one below it). The view aliases x's buffer; nothing is copied.
// Element (i, j) of x is addressed as start + i + j*ld with ld = x.stride()[1], so consecutive
// diagonal elements are ld + 1 apart.
array diag(array_base & x, int offset)
{
  assert(x.dim()==2 && "Input must be 2-d");
  // Row / column index of the first diagonal element.
  int_t offi = (offset < 0) ? -static_cast<int_t>(offset) : 0;
  int_t offj = (offset > 0) ? static_cast<int_t>(offset) : 0;
  // Clamp at zero so an offset past the matrix edge yields an empty view rather than a negative size.
  int_t size = std::max<int_t>(0, std::min(x.shape()[0] - offi, x.shape()[1] - offj));
  // Include x.start() so that diagonals of sub-matrix views begin at the right element.
  int_t start = x.start() + offi + x.stride()[1]*offj;
  return array(size, x.dtype(), x.data(), start, x.stride()[1] + 1);
}
2015-10-04 17:05:06 -04:00
2015-09-30 15:31:41 -04:00
// M x N expression whose elements are all zero: a unary OPERATOR_ADD_TYPE node over a zero scalar of type dtype.
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{
  shape_t const dims{M, N};
  op_element const fill(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE);
  return math_expression(value_scalar(0, dtype), invalid_node(), fill, ctx, dtype, dims);
}
2015-01-20 11:17:42 -05:00
2015-11-19 12:37:18 -05:00
// Rotates a shape left by one position: {a, b} -> {b, a}, {a, b, c} -> {b, c, a}. Shapes of size 0 or 1
// are returned unchanged. Used to compute the shape of a transposed expression.
inline shape_t flip(shape_t const & shape)
{
  shape_t rotated = shape;
  size_t const n = shape.size();
  for(size_t k = 0 ; k < n ; ++k)
  {
    size_t const src = (k + 1 == n) ? 0 : k + 1;
    rotated[k] = shape[src];
  }
  return rotated;
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
//inline size4 prod(size4 const & shape1, size4 const & shape2)
//{ return size4(shape1[0]*shape2[0], shape1[1]*shape2[1]);}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
// Transposition node over x; the result shape is x's shape rotated by flip().
math_expression trans(array_base const & x)
{
  shape_t const transposed = flip(x.shape());
  op_element const transposition(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE);
  return math_expression(x, invalid_node(), transposition, x.context(), x.dtype(), transposed);
}

math_expression trans(math_expression const & x)
{
  shape_t const transposed = flip(x.shape());
  op_element const transposition(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_TRANS_TYPE);
  return math_expression(x, invalid_node(), transposition, x.context(), x.dtype(), transposed);
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
// Tiles A rep1 times along its first dimension and rep2 times along its second. Any operand that is not
// 2-D contributes a second extent of 1. The tiling parameters are carried in the node's rhs tuple.
math_expression repmat(array_base const & A, int_t const & rep1, int_t const & rep2)
{
  int_t rows = A.shape()[0];
  int_t cols = (A.dim() == 2) ? A.shape()[1] : 1;
  shape_t const tiled{rep1*rows, rep2*cols};
  op_element const repeat(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE);
  return math_expression(A, make_tuple(A.context(), rep1, rep2, rows, cols), repeat, A.context(), A.dtype(), tiled);
}

math_expression repmat(math_expression const & A, int_t const & rep1, int_t const & rep2)
{
  int_t rows = A.shape()[0];
  int_t cols = (A.dim() == 2) ? A.shape()[1] : 1;
  shape_t const tiled{rep1*rows, rep2*cols};
  op_element const repeat(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE);
  return math_expression(A, make_tuple(A.context(), rep1, rep2, rows, cols), repeat, A.context(), A.dtype(), tiled);
}
2015-09-30 15:31:41 -04:00
# define DEFINE_ACCESS_ROW(TYPEA, TYPEB) \
math_expression row ( TYPEA const & x , TYPEB const & i ) \
2015-11-19 12:37:18 -05:00
{ return math_expression ( x , i , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OPERATOR_MATRIX_ROW_TYPE ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 1 ] } ) ; }
2015-09-30 15:31:41 -04:00
2015-11-19 12:37:18 -05:00
DEFINE_ACCESS_ROW ( array_base , value_scalar )
DEFINE_ACCESS_ROW ( array_base , for_idx_t )
DEFINE_ACCESS_ROW ( array_base , math_expression )
2015-09-30 15:31:41 -04:00
DEFINE_ACCESS_ROW ( math_expression , value_scalar )
DEFINE_ACCESS_ROW ( math_expression , for_idx_t )
DEFINE_ACCESS_ROW ( math_expression , math_expression )
# define DEFINE_ACCESS_COL(TYPEA, TYPEB) \
math_expression col ( TYPEA const & x , TYPEB const & i ) \
2015-11-19 12:37:18 -05:00
{ return math_expression ( x , i , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OPERATOR_MATRIX_COLUMN_TYPE ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 0 ] } ) ; }
2015-09-30 15:31:41 -04:00
2015-11-19 12:37:18 -05:00
DEFINE_ACCESS_COL ( array_base , value_scalar )
DEFINE_ACCESS_COL ( array_base , for_idx_t )
DEFINE_ACCESS_COL ( array_base , math_expression )
2015-09-30 15:31:41 -04:00
DEFINE_ACCESS_COL ( math_expression , value_scalar )
DEFINE_ACCESS_COL ( math_expression , for_idx_t )
DEFINE_ACCESS_COL ( math_expression , math_expression )
2015-01-12 13:20:53 -05:00
////---------------------------------------
///*--- Reductions ---*/
////---------------------------------------
2015-11-19 12:37:18 -05:00
# define DEFINE_REDUCTION(OP, OPNAME)\
math_expression OPNAME ( array_base const & x , int_t axis ) \
2015-01-12 13:20:53 -05:00
{ \
2015-11-19 12:37:18 -05:00
if ( axis < - 1 | | axis > x . dim ( ) ) \
2015-01-28 17:08:39 -05:00
throw std : : out_of_range ( " The axis entry is out of bounds " ) ; \
else if ( axis = = - 1 ) \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_VECTOR_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { 1 } ) ; \
2015-01-12 13:20:53 -05:00
else if ( axis = = 0 ) \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_COLUMNS_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 1 ] } ) ; \
2015-05-01 15:57:03 -04:00
else \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_ROWS_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 0 ] } ) ; \
2015-01-12 13:20:53 -05:00
} \
\
2015-09-30 15:31:41 -04:00
math_expression OPNAME ( math_expression const & x , int_t axis ) \
2015-01-12 13:20:53 -05:00
{ \
2015-11-19 12:37:18 -05:00
if ( axis < - 1 | | axis > x . dim ( ) ) \
2015-01-28 17:08:39 -05:00
throw std : : out_of_range ( " The axis entry is out of bounds " ) ; \
2015-01-12 13:20:53 -05:00
if ( axis = = - 1 ) \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_VECTOR_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { 1 } ) ; \
2015-01-12 13:20:53 -05:00
else if ( axis = = 0 ) \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_COLUMNS_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 1 ] } ) ; \
2015-05-01 15:57:03 -04:00
else \
2015-11-19 12:37:18 -05:00
return math_expression ( x , invalid_node ( ) , op_element ( OPERATOR_ROWS_DOT_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , { x . shape ( ) [ 0 ] } ) ; \
2015-01-12 13:20:53 -05:00
}
2015-11-19 12:37:18 -05:00
DEFINE_REDUCTION ( OPERATOR_ADD_TYPE , sum )
DEFINE_REDUCTION ( OPERATOR_ELEMENT_ARGMAX_TYPE , argmax )
DEFINE_REDUCTION ( OPERATOR_ELEMENT_MAX_TYPE , max )
DEFINE_REDUCTION ( OPERATOR_ELEMENT_MIN_TYPE , min )
DEFINE_REDUCTION ( OPERATOR_ELEMENT_ARGMIN_TYPE , argmin )
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
# undef DEFINE_REDUCTION
2015-01-12 13:20:53 -05:00
namespace detail
{
2015-11-19 12:37:18 -05:00
math_expression matmatprod ( array_base const & A , array_base const & B )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
shape_t shape { A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] } ;
2015-09-30 15:31:41 -04:00
return math_expression ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , OPERATOR_GEMM_NN_TYPE ) , A . context ( ) , A . dtype ( ) , shape ) ;
2015-01-12 13:20:53 -05:00
}
2015-11-19 12:37:18 -05:00
math_expression matmatprod ( math_expression const & A , array_base const & B )
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-11-19 12:37:18 -05:00
shape_t shape { A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] } ;
2015-01-12 13:20:53 -05:00
2015-09-30 15:31:41 -04:00
math_expression : : node & A_root = const_cast < math_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
if ( A_trans ) {
2015-07-11 09:36:01 -04:00
type = OPERATOR_GEMM_TN_TYPE ;
2015-01-12 13:20:53 -05:00
}
2015-09-30 15:31:41 -04:00
math_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
math_expression : : node & res_root = const_cast < math_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( A_trans ) res_root . lhs = A_root . lhs ;
return res ;
}
2015-11-19 12:37:18 -05:00
math_expression matmatprod ( array_base const & A , math_expression const & B )
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-11-19 12:37:18 -05:00
shape_t shape { A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] } ;
2015-01-12 13:20:53 -05:00
2015-09-30 15:31:41 -04:00
math_expression : : node & B_root = const_cast < math_expression : : node & > ( B . tree ( ) [ B . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool B_trans = B_root . op . type = = OPERATOR_TRANS_TYPE ;
if ( B_trans ) {
2015-07-11 09:36:01 -04:00
type = OPERATOR_GEMM_NT_TYPE ;
2015-01-12 13:20:53 -05:00
}
2015-04-29 15:50:57 -04:00
2015-09-30 15:31:41 -04:00
math_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
math_expression : : node & res_root = const_cast < math_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( B_trans ) res_root . rhs = B_root . lhs ;
return res ;
}
2015-09-30 15:31:41 -04:00
math_expression matmatprod ( math_expression const & A , math_expression const & B )
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-09-30 15:31:41 -04:00
math_expression : : node & A_root = const_cast < math_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
math_expression : : node & B_root = const_cast < math_expression : : node & > ( B . tree ( ) [ B . root ( ) ] ) ;
2015-11-19 12:37:18 -05:00
shape_t shape { A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] } ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
bool B_trans = B_root . op . type = = OPERATOR_TRANS_TYPE ;
2015-06-27 13:53:31 -04:00
2015-07-11 09:36:01 -04:00
if ( A_trans & & B_trans ) type = OPERATOR_GEMM_TT_TYPE ;
else if ( A_trans & & ! B_trans ) type = OPERATOR_GEMM_TN_TYPE ;
else if ( ! A_trans & & B_trans ) type = OPERATOR_GEMM_NT_TYPE ;
else type = OPERATOR_GEMM_NN_TYPE ;
2015-01-12 13:20:53 -05:00
2015-09-30 15:31:41 -04:00
math_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
math_expression : : node & res_root = const_cast < math_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( A_trans ) res_root . lhs = A_root . lhs ;
if ( B_trans ) res_root . rhs = B_root . lhs ;
return res ;
}
template < class T >
2015-11-19 12:37:18 -05:00
math_expression matvecprod ( array_base const & A , T const & x )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
int_t M = A . shape ( ) [ 0 ] ;
int_t N = A . shape ( ) [ 1 ] ;
2015-11-19 12:37:18 -05:00
return sum ( A * repmat ( reshape ( x , { 1 , N } ) , M , 1 ) , 1 ) ;
2015-01-12 13:20:53 -05:00
}
template < class T >
2015-09-30 15:31:41 -04:00
math_expression matvecprod ( math_expression const & A , T const & x )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
int_t M = A . shape ( ) [ 0 ] ;
int_t N = A . shape ( ) [ 1 ] ;
2015-09-30 15:31:41 -04:00
math_expression : : node & A_root = const_cast < math_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
2015-06-30 17:55:57 -04:00
while ( A_root . lhs . type_family = = COMPOSITE_OPERATOR_FAMILY ) {
A_root = A . tree ( ) [ A_root . lhs . node_index ] ;
A_trans ^ = A_root . op . type = = OPERATOR_TRANS_TYPE ;
}
2015-01-17 15:47:52 -05:00
if ( A_trans )
{
2015-11-19 12:37:18 -05:00
math_expression tmp ( A , repmat ( x , 1 , M ) , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_ELEMENT_PROD_TYPE ) , A . context ( ) , A . dtype ( ) , { N , M } ) ;
2015-01-12 13:20:53 -05:00
//Remove trans
2015-01-16 07:31:39 -05:00
tmp . tree ( ) [ tmp . root ( ) ] . lhs = A . tree ( ) [ A . root ( ) ] . lhs ;
2015-05-01 15:57:03 -04:00
return sum ( tmp , 0 ) ;
2015-01-12 13:20:53 -05:00
}
else
2015-11-19 12:37:18 -05:00
return sum ( A * repmat ( reshape ( x , { 1 , N } ) , M , 1 ) , 1 ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-11-19 12:37:18 -05:00
//Swap
// Exchanges the contents of the views x and y by executing both assignments as one fused expression.
// NOTE(review): this relies on each fused assignment reading the other view's value from before the
// swap -- confirm against fuse()/execute() semantics.
ISAACAPI void swap(view x, view y)
{
  //Seems like some compilers will generate incorrect code without the 1*...
  // (the multiplications by 1 are deliberate; do not simplify them away)
  execute(fuse(assign(y, 1*x), assign(x, 1*y)));
}
//Reshape
// Re-labels the elements of x with a new shape (no element-count check is performed here).
math_expression reshape(array_base const & x, shape_t const & shape)
{
  op_element const reshaping(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE);
  return math_expression(x, invalid_node(), reshaping, x.context(), x.dtype(), shape);
}

math_expression reshape(math_expression const & x, shape_t const & shape)
{
  op_element const reshaping(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE);
  return math_expression(x, invalid_node(), reshaping, x.context(), x.dtype(), shape);
}

// Flattens x into a 1-D expression holding all of its elements.
math_expression ravel(array_base const & x)
{
  shape_t const flat{x.shape().prod()};
  return reshape(x, flat);
}
2015-01-18 14:52:45 -05:00
2015-01-12 13:20:53 -05:00
// dot(x, y) dispatches on the operands' ranks and shapes. The order of the tests matters:
//   1. if either operand holds at most one element, the product is element-wise (x*y);
//   2. zero-size operands give a zero matrix or an empty (invalid) expression;
//   3. 1-D . 1-D is a full reduction of x*y;
//   4. matrix/vector combinations go through detail::matvecprod (transposing as needed);
//   5. everything else is a GEMM via detail::matmatprod.
// Because of test 1, every 1-D operand reaching tests 3-5 has more than one element.
// Comments inside the macro body use /* */ since // would swallow the line continuation.
#define DEFINE_DOT(LTYPE, RTYPE) \
math_expression dot(LTYPE const & x, RTYPE const & y)\
{\
  numeric_type dtype = x.dtype();\
  driver::Context const & context = x.context();\
  if(x.shape().max()==1 || y.shape().max()==1)\
    return x*y;\
  /* NOTE(review): y.shape()[1] is read even when y is 1-D -- confirm shape_t tolerates that. */\
  if(x.dim()==2 && x.shape()[1]==0)\
    return zeros(x.shape()[0], y.shape()[1], dtype, context);\
  if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
    return math_expression(invalid_node(), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_INVALID_TYPE), context, dtype, {0});\
  if(x.dim()==1 && y.dim()==1)\
    return sum(x*y);\
  if(x.dim()==2 && x.shape()[0]==1 && y.dim()==1)\
    return sum(x*y);\
  if(x.dim()==2 && y.dim()==1)\
    return detail::matvecprod(x, y);\
  if(x.dim()==1 && y.dim()==2)\
    return trans(detail::matvecprod(trans(y), trans(x)));\
  if(x.shape()[0]==1 && y.shape()[1]==1)\
    return sum(x*trans(y));\
  /* NOTE(review): `== 2` may have been meant as `> 1`. As written, only (1 x K) * (K x 2) takes */\
  /* the mat-vec path (1-D result); wider row-vector products go through GEMM (2-D result).      */\
  if(x.shape()[0]==1 && y.shape()[1]==2)\
    return trans(detail::matvecprod(trans(y), trans(x)));\
  if(x.shape()[1]==1 && y.shape()[0]==1)\
    return x*y;\
  return detail::matmatprod(x, y);\
}

DEFINE_DOT(array_base, array_base)
DEFINE_DOT(math_expression, array_base)
DEFINE_DOT(array_base, math_expression)
DEFINE_DOT(math_expression, math_expression)

#undef DEFINE_DOT
2015-01-17 10:48:02 -05:00
2015-01-16 07:31:39 -05:00
// norm(x, order): order 1 -> sum(abs(x)), order 2 -> sqrt(sum(pow(x, 2))).
// Any other order throws std::invalid_argument in every build type (an assert alone would let
// release builds silently compute the 2-norm).
#define DEFINE_NORM(TYPE)\
math_expression norm(TYPE const & x, unsigned int order)\
{\
  switch(order)\
  {\
    case 1: return sum(abs(x));\
    case 2: return sqrt(sum(pow(x,2)));\
    default: throw std::invalid_argument("norm: order must be 1 or 2");\
  }\
}

DEFINE_NORM(array_base)
DEFINE_NORM(math_expression)

#undef DEFINE_NORM
2015-01-12 13:20:53 -05:00
2015-09-30 15:31:41 -04:00
/*--- Fusion ----*/
// Joins two expressions under a single OPERATOR_FUSE node (presumably so they run as one kernel).
// Both must share a context; the result takes x's context, dtype and shape.
math_expression fuse(math_expression const & x, math_expression const & y)
{
  assert(x.context()==y.context());
  op_element const fusion(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_FUSE);
  return math_expression(x, y, fusion, x.context(), x.dtype(), x.shape());
}
/*--- For loops ---*/
// Wraps x in an OPERATOR_SFOR_TYPE node; the loop bounds (start, end, inc) travel as a tuple in the rhs.
// The result keeps x's context, dtype and shape.
ISAACAPI math_expression sfor(math_expression const & start, math_expression const & end, math_expression const & inc, math_expression const & x)
{
  op_element const loop(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SFOR_TYPE);
  return math_expression(x, make_tuple(x.context(), start, end, inc), loop, x.context(), x.dtype(), x.shape());
}
2015-01-12 13:20:53 -05:00
/*--- Copy ----*/
//---------------------------------------
//void*
2015-11-19 12:37:18 -05:00
void copy ( void const * data , array_base & x , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
unsigned int dtypesize = size_of ( x . dtype ( ) ) ;
2015-11-19 12:37:18 -05:00
if ( x . start ( ) = = 0 & & x . shape ( ) [ 0 ] * x . stride ( ) . prod ( ) = = x . shape ( ) . prod ( ) )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
queue . write ( x . data ( ) , blocking , 0 , x . shape ( ) . prod ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
else
{
2015-11-19 12:37:18 -05:00
array tmp ( x . dtype ( ) , x . shape ( ) , x . context ( ) ) ;
queue . write ( tmp . data ( ) , blocking , 0 , tmp . shape ( ) . prod ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
x = tmp ;
}
}
2015-11-19 12:37:18 -05:00
void copy ( array_base const & x , void * data , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
unsigned int dtypesize = size_of ( x . dtype ( ) ) ;
2015-11-25 18:42:25 -05:00
if ( x . start ( ) = = 0 & & x . stride ( ) . prod ( ) = = x . shape ( ) . prod ( ) )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
queue . read ( x . data ( ) , blocking , 0 , x . shape ( ) . prod ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
else
{
2015-11-19 12:37:18 -05:00
array tmp ( x . dtype ( ) , x . shape ( ) , x . context ( ) ) ;
2015-01-12 13:20:53 -05:00
tmp = x ;
2015-11-19 12:37:18 -05:00
queue . read ( tmp . data ( ) , blocking , 0 , tmp . shape ( ) . prod ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-11-19 12:37:18 -05:00
void copy ( void const * data , array_base & x , bool blocking )
{
copy ( data , x , driver : : backend : : queues : : get ( x . context ( ) , 0 ) , blocking ) ;
}
2015-01-12 13:20:53 -05:00
2015-11-19 12:37:18 -05:00
void copy ( array_base const & x , void * data , bool blocking )
{
copy ( x , data , driver : : backend : : queues : : get ( x . context ( ) , 0 ) , blocking ) ;
}
2015-01-12 13:20:53 -05:00
//std::vector<>
template < class T >
2015-11-19 12:37:18 -05:00
void copy ( std : : vector < T > const & cx , array_base & x , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
assert ( ( int_t ) cx . size ( ) = = x . shape ( ) . prod ( ) ) ;
2015-01-12 13:20:53 -05:00
copy ( ( void const * ) cx . data ( ) , x , queue , blocking ) ;
}
template < class T >
2015-11-19 12:37:18 -05:00
void copy ( array_base const & x , std : : vector < T > & cx , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
2015-11-19 12:37:18 -05:00
assert ( ( int_t ) cx . size ( ) = = x . shape ( ) . prod ( ) ) ;
2015-01-12 13:20:53 -05:00
copy ( x , ( void * ) cx . data ( ) , queue , blocking ) ;
}
template < class T >
2015-11-19 12:37:18 -05:00
void copy ( std : : vector < T > const & cx , array_base & x , bool blocking )
{
copy ( cx , x , driver : : backend : : queues : : get ( x . context ( ) , 0 ) , blocking ) ;
}
2015-01-12 13:20:53 -05:00
template < class T >
2015-11-19 12:37:18 -05:00
void copy ( array_base const & x , std : : vector < T > & cx , bool blocking )
{
copy ( x , cx , driver : : backend : : queues : : get ( x . context ( ) , 0 ) , blocking ) ;
}
2015-01-12 13:20:53 -05:00
// Explicit, exported instantiations of the four std::vector copy overloads for every supported
// host element type.
#define INSTANTIATE(TYPE) \
  template void ISAACAPI copy<TYPE>(std::vector<TYPE> const &, array_base &, driver::CommandQueue &, bool); \
  template void ISAACAPI copy<TYPE>(array_base const &, std::vector<TYPE> &, driver::CommandQueue &, bool); \
  template void ISAACAPI copy<TYPE>(std::vector<TYPE> const &, array_base &, bool); \
  template void ISAACAPI copy<TYPE>(array_base const &, std::vector<TYPE> &, bool)

// Floating point
INSTANTIATE(float);
INSTANTIATE(double);
// Signed integers
INSTANTIATE(char);
INSTANTIATE(short);
INSTANTIATE(int);
INSTANTIATE(long);
INSTANTIATE(long long);
// Unsigned integers
INSTANTIATE(unsigned char);
INSTANTIATE(unsigned short);
INSTANTIATE(unsigned int);
INSTANTIATE(unsigned long);
INSTANTIATE(unsigned long long);

#undef INSTANTIATE
/*--- Stream operators----*/
//---------------------------------------
namespace detail
{
  // Writes the elements *begin, *(begin+stride), ... (stopping before end) to os, separated by ','.
  // At most WINDOW leading and WINDOW trailing elements are printed; ",..." marks skipped elements.
  // Surrounding brackets are omitted when col is true (used for matrix rows).
  template<typename ItType>
  static std::ostream & prettyprint(std::ostream& os, ItType begin, ItType const & end, size_t stride = 1, bool col = false, size_t WINDOW = 10)
  {
    if(!col)
      os << "[";
    size_t N = (end - begin)/stride;
    size_t upper = std::min(WINDOW, N);
    for(size_t j = 0 ; j < upper ; j++)
    {
      os << *begin;
      if(j < upper - 1)
        os << ",";
      begin += stride;
    }
    if(upper < N)
    {
      if(N - upper > WINDOW)
        os << ",...";
      // begin now points at element `upper`; jump to the first element of the trailing window,
      // otherwise the elements right after the leading window would be printed as the tail.
      size_t tail = std::max(N - WINDOW, upper);
      begin += (tail - upper)*stride;
      for(size_t j = tail ; j < N ; j++)
      {
        os << "," << *begin;
        begin += stride;
      }
    }
    if(!col)
      os << "]";
    return os;
  }
}
2015-11-19 12:37:18 -05:00
std : : ostream & operator < < ( std : : ostream & os , array_base const & a )
2015-01-12 13:20:53 -05:00
{
size_t WINDOW = 10 ;
2015-01-20 21:02:24 -05:00
numeric_type dtype = a . dtype ( ) ;
2015-04-29 15:50:57 -04:00
size_t M = a . shape ( ) [ 0 ] ;
2015-11-19 12:37:18 -05:00
size_t N = ( a . dim ( ) = = 1 ) ? 1 : a . shape ( ) [ 1 ] ;
2015-01-13 14:44:19 -05:00
2015-11-19 12:37:18 -05:00
void * tmp = new char [ a . shape ( ) . prod ( ) * size_of ( dtype ) ] ;
2015-01-20 21:02:24 -05:00
copy ( a , ( void * ) tmp ) ;
2015-01-12 13:20:53 -05:00
os < < " [ " ;
size_t upper = std : : min ( WINDOW , M ) ;
2015-01-20 21:02:24 -05:00
# define HANDLE(ADTYPE, CTYPE) case ADTYPE: detail::prettyprint(os, reinterpret_cast<CTYPE*>(tmp) + i, reinterpret_cast<CTYPE*>(tmp) + M*N + i, M, true, WINDOW); break;
2015-01-12 13:20:53 -05:00
for ( unsigned int i = 0 ; i < upper ; + + i )
{
if ( i > 0 )
os < < " " ;
2015-01-20 21:02:24 -05:00
switch ( dtype )
{
2015-01-29 15:19:40 -05:00
// HANDLE(BOOL_TYPE, cl_bool)
2015-05-04 21:23:05 -04:00
HANDLE ( CHAR_TYPE , char )
HANDLE ( UCHAR_TYPE , unsigned char )
HANDLE ( SHORT_TYPE , short )
HANDLE ( USHORT_TYPE , unsigned short )
HANDLE ( INT_TYPE , int )
HANDLE ( UINT_TYPE , unsigned int )
HANDLE ( LONG_TYPE , long )
HANDLE ( ULONG_TYPE , unsigned long )
2015-01-29 15:19:40 -05:00
// HANDLE(HALF_TYPE, cl_half)
2015-05-04 21:23:05 -04:00
HANDLE ( FLOAT_TYPE , float )
HANDLE ( DOUBLE_TYPE , double )
2015-01-20 21:02:24 -05:00
default : throw unknown_datatype ( dtype ) ;
}
2015-01-12 13:20:53 -05:00
if ( i < upper - 1 )
os < < std : : endl ;
}
if ( upper < M )
{
if ( N - upper > WINDOW )
os < < std : : endl < < " ... " ;
for ( size_t i = std : : max ( N - WINDOW , upper ) ; i < N ; i + + )
{
os < < std : : endl < < " " ;
2015-01-20 21:02:24 -05:00
switch ( dtype )
{
2015-01-29 15:19:40 -05:00
// HANDLE(BOOL_TYPE, cl_bool)
2015-05-04 21:23:05 -04:00
HANDLE ( CHAR_TYPE , char )
HANDLE ( UCHAR_TYPE , unsigned char )
HANDLE ( SHORT_TYPE , short )
HANDLE ( USHORT_TYPE , unsigned short )
HANDLE ( INT_TYPE , int )
HANDLE ( UINT_TYPE , unsigned int )
HANDLE ( LONG_TYPE , long )
HANDLE ( ULONG_TYPE , unsigned long )
2015-01-29 15:19:40 -05:00
// HANDLE(HALF_TYPE, cl_half)
2015-05-04 21:23:05 -04:00
HANDLE ( FLOAT_TYPE , float )
HANDLE ( DOUBLE_TYPE , double )
2015-01-20 21:02:24 -05:00
default : throw unknown_datatype ( dtype ) ;
}
2015-01-12 13:20:53 -05:00
}
}
os < < " ] " ;
return os ;
}
2015-11-19 12:37:18 -05:00
// Prints an expression by evaluating it into a temporary array and printing that array.
ISAACAPI std::ostream & operator<<(std::ostream & oss, math_expression const & expression)
{
  array const evaluated(expression);
  oss << evaluated;
  return oss;
}
2015-01-12 13:20:53 -05:00
}