2015-01-12 13:20:53 -05:00
# include <cassert>
2015-07-20 18:02:56 -07:00
# include <algorithm>
# include <stdexcept>
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
# include "isaac/array.h"
# include "isaac/exception/unknown_datatype.h"
# include "isaac/model/model.h"
# include "isaac/symbolic/execute.h"
2015-01-12 13:20:53 -05:00
2015-01-28 17:08:39 -05:00
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
2015-07-21 17:18:50 -04:00
namespace detail
{
  // Largest of the first two extents of a shape. Callers below use it to spot
  // operands whose first two extents are both 1 and to size outer products.
  inline int_t max(size4 const & s)
  {
    return std::max(s[0], s[1]);
  }
}
2015-02-04 22:06:15 -05:00
2015-01-12 13:20:53 -05:00
/*--- Constructors ---*/
//1D Constructors
2015-04-29 15:50:57 -04:00
array : : array ( int_t shape0 , numeric_type dtype , driver : : Context context ) :
dtype_ ( dtype ) , shape_ ( shape0 , 1 , 1 , 1 ) , start_ ( 0 , 0 , 0 , 0 ) , stride_ ( 1 , 1 , 1 , 1 ) , ld_ ( shape_ [ 0 ] ) ,
context_ ( context ) , data_ ( context_ , size_of ( dtype ) * dsize ( ) )
2015-02-04 22:06:15 -05:00
{ }
2015-01-12 13:20:53 -05:00
2015-06-25 08:12:16 -07:00
// Wraps an existing buffer as a 1D view of shape0 elements. The first element
// sits at buffer index `start` and consecutive elements are `inc` apart.
// Nothing is allocated here: data_ is copied from `data` (presumably a shared
// handle — confirm in driver::Buffer), and the context is taken from it.
array::array(int_t shape0, numeric_type dtype, driver::Buffer data, int_t start, int_t inc)
  : dtype_(dtype),
    shape_(shape0),
    start_(start, 0, 0, 0),
    stride_(inc),
    ld_(shape_[0]),
    context_(data.context()),
    data_(data)
{
}
2015-01-21 20:08:52 -05:00
template < class DT >
2015-04-29 15:50:57 -04:00
array : : array ( std : : vector < DT > const & x , driver : : Context context ) :
dtype_ ( to_numeric_type < DT > : : value ) , shape_ ( x . size ( ) , 1 ) , start_ ( 0 , 0 , 0 , 0 ) , stride_ ( 1 , 1 , 1 , 1 ) , ld_ ( shape_ [ 0 ] ) ,
context_ ( context ) , data_ ( context , size_of ( dtype_ ) * dsize ( ) )
2015-01-12 13:20:53 -05:00
{ * this = x ; }
2015-04-29 15:50:57 -04:00
// 1D view of `v` selected by `s0`. The slice's start and stride are composed
// with v's own origin and stride, so slicing a view yields another view into
// the same buffer (data_ is copied from v.data_, nothing is allocated here).
array::array(array & v, slice const & s0)
  : dtype_(v.dtype_),
    shape_(s0.size, 1, 1, 1),
    start_(v.start_[0] + v.stride_[0] * s0.start, 0, 0, 0),
    stride_(v.stride_[0] * s0.stride, 1, 1, 1),
    ld_(v.ld_),
    context_(v.data_.context()),
    data_(v.data_)
{
}
2015-01-12 13:20:53 -05:00
2015-07-21 23:48:50 -07:00
// Explicit, exported instantiations of the std::vector<T> constructor for
// every supported host element type.
#define INSTANTIATE(T) template ISAACAPI array::array(std::vector<T> const &, driver::Context)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
// 2D
2015-06-25 08:12:16 -07:00
array : : array ( int_t shape0 , int_t shape1 , numeric_type dtype , driver : : Context context ) : dtype_ ( dtype ) , shape_ ( shape0 , shape1 ) , start_ ( 0 , 0 , 0 , 0 ) , stride_ ( 1 , 1 , 1 , 1 ) , ld_ ( shape0 ) ,
2015-04-29 15:50:57 -04:00
context_ ( context ) , data_ ( context_ , size_of ( dtype_ ) * dsize ( ) )
2015-01-12 13:20:53 -05:00
{ }
2015-06-25 08:12:16 -07:00
// Wraps an existing buffer as a shape0 x shape1 view with leading dimension
// `ld`, starting at buffer index `start` (stored in start_[0]).
// No allocation: data_ is copied from `data`, and its context is reused.
array::array(int_t shape0, int_t shape1, numeric_type dtype, driver::Buffer data, int_t start, int_t ld)
  : dtype_(dtype),
    shape_(shape0, shape1),
    start_(start, 0, 0, 0),
    stride_(1, 1, 1, 1),
    ld_(ld),
    context_(data.context()),
    data_(data)
{
}
2015-04-29 15:50:57 -04:00
// 2D view of `M` selected by the row slice s0 and the column slice s1.
// Each axis composes the slice with M's own origin and stride; the view shares
// M's buffer and leading dimension.
array::array(array & M, slice const & s0, slice const & s1)
  : dtype_(M.dtype_),
    shape_(s0.size, s1.size, 1, 1),
    start_(M.start_[0] + M.stride_[0] * s0.start,
           M.start_[1] + M.stride_[1] * s1.start, 0, 0),
    stride_(M.stride_[0] * s0.stride,
            M.stride_[1] * s1.stride, 1, 1),
    ld_(M.ld_),
    context_(M.data_.context()),
    data_(M.data_)
{
}
2015-06-28 17:53:16 -07:00
2015-01-21 20:08:52 -05:00
template < typename DT >
2015-04-29 15:50:57 -04:00
array : : array ( int_t shape0 , int_t shape1 , std : : vector < DT > const & data , driver : : Context context )
2015-02-04 22:06:15 -05:00
: dtype_ ( to_numeric_type < DT > : : value ) ,
2015-04-29 15:50:57 -04:00
shape_ ( shape0 , shape1 ) , start_ ( 0 , 0 ) , stride_ ( 1 , 1 ) , ld_ ( shape0 ) ,
context_ ( context ) , data_ ( context_ , size_of ( dtype_ ) * dsize ( ) )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
isaac : : copy ( data , * this ) ;
2015-01-12 13:20:53 -05:00
}
2015-06-25 08:12:16 -07:00
// 3D
// Allocates a fresh shape0 x shape1 x shape2 array with ld_ = shape0,
// zero origin and unit strides.
array::array(int_t shape0, int_t shape1, int_t shape2, numeric_type dtype, driver::Context context)
  : dtype_(dtype),
    shape_(shape0, shape1, shape2, 1),
    start_(0, 0, 0, 0),
    stride_(1, 1, 1, 1),
    ld_(shape0),
    context_(context),
    data_(context_, size_of(dtype_) * dsize())
{
}
//Slices
// View of an existing buffer described directly by two slices: s0 selects
// along the first axis, s1 along the second, and `ld` is the leading dimension.
// No allocation: data_ is copied from `data`, and its context is reused.
// NOTE(review): start_/stride_ use the two-component size4 form; confirm how
// size4 fills components 2-3 if those are ever read for such views.
array::array(numeric_type dtype, driver::Buffer data, slice const & s0, slice const & s1, int_t ld)
  : dtype_(dtype),
    shape_(s0.size, s1.size),
    start_(s0.start, s1.start),
    stride_(s0.stride, s1.stride),
    ld_(ld),
    context_(data.context()),
    data_(data)
{
}
2015-07-21 23:48:50 -07:00
// Explicit, exported instantiations of the (rows, cols, std::vector<T>)
// constructor for every supported host element type.
#define INSTANTIATE(T) template ISAACAPI array::array(int_t, int_t, std::vector<T> const &, driver::Context)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
2015-02-06 22:11:03 -05:00
// Both constructors delegate to the controller-based constructor. That
// constructor allocates fresh storage and assigns (evaluates) the source into
// it, so copying an array produces a deep copy.
array::array(array_expression const & proxy)
  : array(control(proxy))
{
}

array::array(array const & other)
  : array(control(other))
{
}
2015-01-12 13:20:53 -05:00
2015-02-06 22:11:03 -05:00
// Allocates a fresh array with the dtype, shape and context of other.x(), then
// evaluates the controlled expression into it through operator=. The
// controller's execution/dispatch/compilation options are forwarded.
template<class TYPE>
array::array(controller<TYPE> const & other)
  : dtype_(other.x().dtype()),
    shape_(other.x().shape()),
    // Spelled out in full, like the other allocating constructors, so that a
    // freshly allocated result always has a zero origin and unit strides on
    // every axis. The short size4(a, b) form fills trailing components with 1.
    start_(0, 0, 0, 0),
    stride_(1, 1, 1, 1),
    ld_(shape_[0]),
    context_(other.x().context()),
    data_(context_, size_of(dtype_) * dsize())
{
  *this = other;
}

template ISAACAPI array::array(controller<array> const &);
template ISAACAPI array::array(controller<array_expression> const &);
2015-02-06 22:11:03 -05:00
2015-01-12 13:20:53 -05:00
/*--- Getters ---*/
numeric_type array : : dtype ( ) const
2015-02-04 22:06:15 -05:00
{ return dtype_ ; }
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
size4 const & array : : shape ( ) const
2015-02-04 22:06:15 -05:00
{ return shape_ ; }
2015-01-12 13:20:53 -05:00
int_t array : : nshape ( ) const
2015-04-29 15:50:57 -04:00
{ return int_t ( ( shape_ [ 0 ] > 1 ) + ( shape_ [ 1 ] > 1 ) ) ; }
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
size4 const & array : : start ( ) const
2015-02-04 22:06:15 -05:00
{ return start_ ; }
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
size4 const & array : : stride ( ) const
2015-02-04 22:06:15 -05:00
{ return stride_ ; }
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
int_t const & array : : ld ( ) const
2015-02-04 22:06:15 -05:00
{ return ld_ ; }
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
driver : : Context const & array : : context ( ) const
2015-01-12 13:20:53 -05:00
{ return context_ ; }
2015-04-29 15:50:57 -04:00
driver : : Buffer const & array : : data ( ) const
2015-01-12 13:20:53 -05:00
{ return data_ ; }
2015-04-29 15:50:57 -04:00
driver : : Buffer & array : : data ( )
{ return data_ ; }
2015-01-12 13:20:53 -05:00
int_t array : : dsize ( ) const
2015-04-29 15:50:57 -04:00
{ return ld_ * shape_ [ 1 ] * shape_ [ 2 ] * shape_ [ 3 ] ; }
2015-01-12 13:20:53 -05:00
/*--- Assignment Operators ----*/
//---------------------------------------
// Assigning an array, an expression or a value_scalar wraps the right-hand
// side in a controller with default options and funnels through the
// controller overload below.
array & array::operator=(array const & rhs)
{
  return *this = controller<array>(rhs);
}

array & array::operator=(array_expression const & rhs)
{
  return *this = controller<array_expression>(rhs);
}

// Builds the expression tree "*this = c.x()" and executes it with the
// controller's execution, dispatcher and compilation options, using the models
// associated with the queue those options select for this array's context.
template<class TYPE>
array & array::operator=(controller<TYPE> const & c)
{
  assert(dtype_ == c.x().dtype());
  array_expression assignment(*this, c.x(),
                              op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ASSIGN_TYPE),
                              context_, dtype_, shape_);
  execute(controller<array_expression>(assignment, c.execution_options(),
                                       c.dispatcher_options(), c.compilation_options()),
          isaac::models(c.execution_options().queue(context_)));
  return *this;
}

// Host-to-device upload of a std::vector; only valid for vector-shaped arrays.
template<class DT>
array & array::operator=(std::vector<DT> const & rhs)
{
  assert(nshape() <= 1);
  isaac::copy(rhs, *this);
  return *this;
}

array & array::operator=(value_scalar const & rhs)
{
  return *this = controller<value_scalar>(rhs);
}
2015-07-21 23:48:50 -07:00
// Explicit, exported instantiations of operator=(std::vector<T> const &) for
// every supported host element type.
#define INSTANTIATE(T) template ISAACAPI array & array::operator=<T>(std::vector<T> const &)
INSTANTIATE(char);      INSTANTIATE(unsigned char);
INSTANTIATE(short);     INSTANTIATE(unsigned short);
INSTANTIATE(int);       INSTANTIATE(unsigned int);
INSTANTIATE(long);      INSTANTIATE(unsigned long);
INSTANTIATE(long long); INSTANTIATE(unsigned long long);
INSTANTIATE(float);     INSTANTIATE(double);
#undef INSTANTIATE
2015-01-19 21:29:47 -05:00
// Elementwise negation; the result keeps this array's dtype and shape.
array_expression array::operator-()
{
  op_element const negate(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE);
  return array_expression(*this, invalid_node(), negate, context_, dtype_, shape_);
}

// Elementwise logical not; the result is always INT_TYPE.
array_expression array::operator!()
{
  op_element const logical_not(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE);
  return array_expression(*this, invalid_node(), logical_not, context_, INT_TYPE, shape_);
}
2015-01-29 15:19:40 -05:00
2015-01-12 13:20:53 -05:00
//
2015-01-19 21:29:47 -05:00
array & array : : operator + = ( value_scalar const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_ADD_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
array & array : : operator + = ( array const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_ADD_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
array & array : : operator + = ( array_expression const & rhs )
2015-02-08 00:56:24 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_ADD_TYPE ) , rhs . context ( ) , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
//----
array & array : : operator - = ( value_scalar const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_SUB_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
array & array : : operator - = ( array const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_SUB_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
array & array : : operator - = ( array_expression const & rhs )
2015-02-08 00:56:24 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_SUB_TYPE ) , rhs . context ( ) , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
//----
2015-01-12 13:20:53 -05:00
array & array : : operator * = ( value_scalar const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_MULT_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-12 13:20:53 -05:00
array & array : : operator * = ( array const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_MULT_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-12 13:20:53 -05:00
array & array : : operator * = ( array_expression const & rhs )
2015-02-08 00:56:24 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_MULT_TYPE ) , rhs . context ( ) , dtype_ , shape_ ) ; }
2015-01-19 21:29:47 -05:00
//----
2015-01-12 13:20:53 -05:00
array & array : : operator / = ( value_scalar const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_DIV_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-12 13:20:53 -05:00
array & array : : operator / = ( array const & rhs )
2015-02-04 22:06:15 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_DIV_TYPE ) , context_ , dtype_ , shape_ ) ; }
2015-01-12 13:20:53 -05:00
array & array : : operator / = ( array_expression const & rhs )
2015-02-08 00:56:24 -05:00
{ return * this = array_expression ( * this , rhs , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_DIV_TYPE ) , rhs . context ( ) , dtype_ , shape_ ) ; }
2015-01-12 13:20:53 -05:00
2015-01-21 20:08:52 -05:00
// Transpose of this array, returned as an (unevaluated) expression.
array_expression array::T() const
{
  return isaac::trans(*this);
}
2015-01-21 20:08:52 -05:00
2015-01-12 13:20:53 -05:00
/*--- Indexing operators -----*/
//---------------------------------------
2015-01-13 01:17:27 -05:00
scalar array : : operator [ ] ( int_t idx )
{
2015-06-29 21:52:50 -07:00
assert ( nshape ( ) < = 1 ) ;
2015-06-24 07:51:27 -07:00
return scalar ( dtype_ , data_ , idx ) ;
2015-01-13 01:17:27 -05:00
}
2015-01-19 21:29:47 -05:00
const scalar array : : operator [ ] ( int_t idx ) const
{
2015-06-29 21:52:50 -07:00
assert ( nshape ( ) < = 1 ) ;
2015-06-24 07:51:27 -07:00
return scalar ( dtype_ , data_ , idx ) ;
2015-01-19 21:29:47 -05:00
}
2015-01-12 13:20:53 -05:00
array array : : operator [ ] ( slice const & e1 )
{
2015-06-29 21:52:50 -07:00
assert ( nshape ( ) < = 1 ) ;
2015-01-12 13:20:53 -05:00
return array ( * this , e1 ) ;
}
array array : : operator ( ) ( slice const & e1 , slice const & e2 )
{ return array ( * this , e1 , e2 ) ; }
2015-05-01 21:39:29 -04:00
//---------------------------------------
2015-01-12 13:20:53 -05:00
/*--- Scalar ---*/
namespace detail
{
template < class T >
2015-04-29 15:50:57 -04:00
void copy ( driver : : Context & ctx , driver : : Buffer const & data , T value )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
driver : : queues [ ctx ] [ 0 ] . write ( data , CL_TRUE , 0 , sizeof ( T ) , ( void * ) & value ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-06-24 07:51:27 -07:00
// Scalar view of buffer element `offset`; nothing is allocated.
// NOTE(review): the second slice is _(1, 2) rather than _(0, 1); confirm that
// the resulting start_[1] is intentionally irrelevant for scalars (inject()
// and operator= only read start_[0]).
scalar::scalar(numeric_type dtype, const driver::Buffer & data, int_t offset)
  : array(dtype, data, _(offset, offset + 1), _(1, 2), 1)
{
}

// Allocates a one-element array with value's dtype and uploads `value`,
// converted to the matching host type. Throws unknown_datatype for dtypes
// without a host counterpart.
scalar::scalar(value_scalar value, driver::Context context)
  : array(1, value.dtype(), context)
{
#define ISAAC_STORE_AS(DT, CTYPE) case DT: detail::copy(context_, data_, static_cast<CTYPE>(value)); break
  switch(dtype_)
  {
    ISAAC_STORE_AS(CHAR_TYPE, char);
    ISAAC_STORE_AS(UCHAR_TYPE, unsigned char);
    ISAAC_STORE_AS(SHORT_TYPE, short);
    ISAAC_STORE_AS(USHORT_TYPE, unsigned short);
    ISAAC_STORE_AS(INT_TYPE, int);
    ISAAC_STORE_AS(UINT_TYPE, unsigned int);
    ISAAC_STORE_AS(LONG_TYPE, long);
    ISAAC_STORE_AS(ULONG_TYPE, unsigned long);
    ISAAC_STORE_AS(FLOAT_TYPE, float);
    ISAAC_STORE_AS(DOUBLE_TYPE, double);
    default: throw unknown_datatype(dtype_);
  }
#undef ISAAC_STORE_AS
}

// Allocates an uninitialised one-element array of the given dtype.
scalar::scalar(numeric_type dtype, driver::Context context)
  : array(1, dtype, context)
{
}

// Evaluates an expression into a freshly allocated scalar.
scalar::scalar(array_expression const & proxy)
  : array(proxy)
{
}
2015-05-04 21:23:05 -04:00
void scalar : : inject ( values_holder & v ) const
{
int_t dtsize = size_of ( dtype_ ) ;
# define HANDLE_CASE(DTYPE, VAL) \
case DTYPE : \
2015-05-13 02:20:44 -04:00
driver : : queues [ context_ ] [ 0 ] . read ( data_ , CL_TRUE , start_ [ 0 ] * dtsize , dtsize , ( void * ) & v . VAL ) ; break ; \
2015-05-04 21:23:05 -04:00
switch ( dtype_ )
{
HANDLE_CASE ( CHAR_TYPE , int8 ) ;
HANDLE_CASE ( UCHAR_TYPE , uint8 ) ;
HANDLE_CASE ( SHORT_TYPE , int16 ) ;
HANDLE_CASE ( USHORT_TYPE , uint16 ) ;
HANDLE_CASE ( INT_TYPE , int32 ) ;
HANDLE_CASE ( UINT_TYPE , uint32 ) ;
HANDLE_CASE ( LONG_TYPE , int64 ) ;
HANDLE_CASE ( ULONG_TYPE , uint64 ) ;
HANDLE_CASE ( FLOAT_TYPE , float32 ) ;
HANDLE_CASE ( DOUBLE_TYPE , float64 ) ;
default : throw unknown_datatype ( dtype_ ) ;
}
# undef HANDLE_CASE
}
2015-07-21 17:18:50 -04:00
template < class TYPE >
TYPE scalar : : cast ( ) const
2015-01-12 13:20:53 -05:00
{
2015-01-13 01:17:27 -05:00
values_holder v ;
2015-05-04 21:23:05 -04:00
inject ( v ) ;
2015-07-21 17:18:50 -04:00
# define HANDLE_CASE(DTYPE, VAL) case DTYPE: return static_cast<TYPE>(v.VAL)
2015-01-13 01:17:27 -05:00
2015-02-04 22:06:15 -05:00
switch ( dtype_ )
2015-01-13 01:17:27 -05:00
{
HANDLE_CASE ( CHAR_TYPE , int8 ) ;
HANDLE_CASE ( UCHAR_TYPE , uint8 ) ;
HANDLE_CASE ( SHORT_TYPE , int16 ) ;
HANDLE_CASE ( USHORT_TYPE , uint16 ) ;
HANDLE_CASE ( INT_TYPE , int32 ) ;
HANDLE_CASE ( UINT_TYPE , uint32 ) ;
HANDLE_CASE ( LONG_TYPE , int64 ) ;
HANDLE_CASE ( ULONG_TYPE , uint64 ) ;
HANDLE_CASE ( FLOAT_TYPE , float32 ) ;
HANDLE_CASE ( DOUBLE_TYPE , float64 ) ;
2015-02-04 22:06:15 -05:00
default : throw unknown_datatype ( dtype_ ) ;
2015-01-13 01:17:27 -05:00
}
# undef HANDLE_CASE
}
// Converts `s` to the host type matching dtype_ and writes it (blocking) to
// this scalar's element start_[0] on the context's first command queue.
// Throws unknown_datatype for dtypes without a host counterpart.
scalar & scalar::operator=(value_scalar const & s)
{
  driver::CommandQueue & queue = driver::queues[context_][0];
  int_t dtsize = size_of(dtype_);
#define HANDLE_CASE(TYPE, CLTYPE) case TYPE: \
  { \
    CLTYPE v = s; \
    queue.write(data_, CL_TRUE, start_[0] * dtsize, dtsize, static_cast<void *>(&v)); \
    return *this; \
  }
  switch(dtype_)
  {
    HANDLE_CASE(CHAR_TYPE, char)
    HANDLE_CASE(UCHAR_TYPE, unsigned char)
    HANDLE_CASE(SHORT_TYPE, short)
    HANDLE_CASE(USHORT_TYPE, unsigned short)
    HANDLE_CASE(INT_TYPE, int)
    HANDLE_CASE(UINT_TYPE, unsigned int)
    HANDLE_CASE(LONG_TYPE, long)
    HANDLE_CASE(ULONG_TYPE, unsigned long)
    HANDLE_CASE(FLOAT_TYPE, float)
    HANDLE_CASE(DOUBLE_TYPE, double)
    default: throw unknown_datatype(dtype_);
  }
  // Undefine like the other HANDLE_CASE helpers in this file, so the macro
  // does not leak into the rest of the translation unit and any later
  // HANDLE_CASE definition is not flagged as a redefinition.
#undef HANDLE_CASE
}
2015-01-13 01:17:27 -05:00
// Implicit conversions of a scalar to every supported host type; each one
// reads the value back from the device through cast<T>().
#define INSTANTIATE(type) scalar::operator type() const { return cast<type>(); }
INSTANTIATE(char)      INSTANTIATE(unsigned char)
INSTANTIATE(short)     INSTANTIATE(unsigned short)
INSTANTIATE(int)       INSTANTIATE(unsigned int)
INSTANTIATE(long)      INSTANTIATE(unsigned long)
INSTANTIATE(long long) INSTANTIATE(unsigned long long)
INSTANTIATE(float)     INSTANTIATE(double)
#undef INSTANTIATE
2015-01-12 13:20:53 -05:00
2015-01-16 07:31:39 -05:00
std : : ostream & operator < < ( std : : ostream & os , scalar const & s )
2015-01-12 13:20:53 -05:00
{
switch ( s . dtype ( ) )
{
2015-07-21 17:18:50 -04:00
// case BOOL_TYPE: return os << static_cast<bool>(s);
2015-05-04 21:23:05 -04:00
case CHAR_TYPE : return os < < static_cast < char > ( s ) ;
case UCHAR_TYPE : return os < < static_cast < unsigned char > ( s ) ;
case SHORT_TYPE : return os < < static_cast < short > ( s ) ;
case USHORT_TYPE : return os < < static_cast < unsigned short > ( s ) ;
case INT_TYPE : return os < < static_cast < int > ( s ) ;
case UINT_TYPE : return os < < static_cast < unsigned int > ( s ) ;
case LONG_TYPE : return os < < static_cast < long > ( s ) ;
case ULONG_TYPE : return os < < static_cast < unsigned long > ( s ) ;
2015-07-21 17:18:50 -04:00
// case HALF_TYPE: return os << static_cast<half>(s);
2015-05-04 21:23:05 -04:00
case FLOAT_TYPE : return os < < static_cast < float > ( s ) ;
case DOUBLE_TYPE : return os < < static_cast < double > ( s ) ;
2015-01-19 21:29:47 -05:00
default : throw unknown_datatype ( s . dtype ( ) ) ;
2015-01-12 13:20:53 -05:00
}
}
/*--- Binary Operators ----*/
//-----------------------------------
2015-01-16 15:24:24 -05:00
template < class U , class V >
size4 elementwise_size ( U const & u , V const & v )
{
2015-07-21 17:18:50 -04:00
if ( detail : : max ( u . shape ( ) ) = = 1 )
2015-01-16 15:24:24 -05:00
return v . shape ( ) ;
return u . shape ( ) ;
}
2015-01-12 13:20:53 -05:00
template < class U , class V >
bool check_elementwise ( U const & u , V const & v )
{
2015-07-21 17:18:50 -04:00
return detail : : max ( u . shape ( ) ) = = 1 | | detail : : max ( v . shape ( ) ) = = 1 | | u . shape ( ) = = v . shape ( ) ;
2015-01-12 13:20:53 -05:00
}
2015-01-29 01:00:50 -05:00
# define DEFINE_ELEMENT_BINARY_OPERATOR(OP, OPNAME, DTYPE) \
2015-01-12 13:20:53 -05:00
array_expression OPNAME ( array_expression const & x , array_expression const & y ) \
{ assert ( check_elementwise ( x , y ) ) ; \
2015-02-08 00:56:24 -05:00
return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , elementwise_size ( x , y ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array const & x , array_expression const & y ) \
{ assert ( check_elementwise ( x , y ) ) ; \
2015-02-08 00:56:24 -05:00
return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , elementwise_size ( x , y ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array_expression const & x , array const & y ) \
{ assert ( check_elementwise ( x , y ) ) ; \
2015-02-08 00:56:24 -05:00
return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , elementwise_size ( x , y ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array const & x , array const & y ) \
{ assert ( check_elementwise ( x , y ) ) ; \
2015-01-29 01:00:50 -05:00
return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , elementwise_size ( x , y ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array_expression const & x , value_scalar const & y ) \
2015-02-08 00:56:24 -05:00
{ return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , x . shape ( ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array const & x , value_scalar const & y ) \
2015-01-29 01:00:50 -05:00
{ return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , x . shape ( ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( value_scalar const & y , array_expression const & x ) \
2015-02-08 00:56:24 -05:00
{ return array_expression ( y , x , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , x . shape ( ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( value_scalar const & y , array const & x ) \
2015-01-29 01:00:50 -05:00
{ return array_expression ( y , x , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OP ) , x . context ( ) , DTYPE , x . shape ( ) ) ; }
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ADD_TYPE , operator + , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_SUB_TYPE , operator - , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_MULT_TYPE , operator * , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_DIV_TYPE , operator / , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_MAX_TYPE , maximum , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_MIN_TYPE , minimum , x . dtype ( ) )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_POW_TYPE , pow , x . dtype ( ) )
2015-06-28 17:53:16 -07:00
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ASSIGN_TYPE , assign , x . dtype ( ) )
2015-01-12 13:20:53 -05:00
2015-01-29 01:00:50 -05:00
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_GREATER_TYPE , operator > , INT_TYPE )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_GEQ_TYPE , operator > = , INT_TYPE )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_LESS_TYPE , operator < , INT_TYPE )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_LEQ_TYPE , operator < = , INT_TYPE )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_EQ_TYPE , operator = = , INT_TYPE )
DEFINE_ELEMENT_BINARY_OPERATOR ( OPERATOR_ELEMENT_NEQ_TYPE , operator ! = , INT_TYPE )
2015-01-12 13:20:53 -05:00
2015-06-30 17:55:57 -04:00
# define DEFINE_OUTER(LTYPE, RTYPE) \
array_expression outer ( LTYPE const & x , RTYPE const & y ) \
{ \
assert ( x . nshape ( ) = = 1 & & y . nshape ( ) = = 1 ) ; \
2015-07-21 17:18:50 -04:00
return array_expression ( x , y , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_OUTER_PROD_TYPE ) , x . context ( ) , x . dtype ( ) , size4 ( detail : : max ( x . shape ( ) ) , detail : : max ( y . shape ( ) ) ) ) ; \
2015-06-30 17:55:57 -04:00
} \
2015-01-12 13:20:53 -05:00
2015-06-30 17:55:57 -04:00
DEFINE_OUTER ( array , array )
DEFINE_OUTER ( array_expression , array )
DEFINE_OUTER ( array , array_expression )
DEFINE_OUTER ( array_expression , array_expression )
2015-01-12 13:20:53 -05:00
# undef DEFINE_ELEMENT_BINARY_OPERATOR
//---------------------------------------
/*--- Math Operators----*/
//---------------------------------------
# define DEFINE_ELEMENT_UNARY_OPERATOR(OP, OPNAME) \
array_expression OPNAME ( array const & x ) \
2015-01-31 22:01:48 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , x . shape ( ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression OPNAME ( array_expression const & x ) \
2015-02-08 00:56:24 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OP ) , x . context ( ) , x . dtype ( ) , x . shape ( ) ) ; }
2015-01-12 13:20:53 -05:00
DEFINE_ELEMENT_UNARY_OPERATOR ( ( x . dtype ( ) = = FLOAT_TYPE | | x . dtype ( ) = = DOUBLE_TYPE ) ? OPERATOR_FABS_TYPE : OPERATOR_ABS_TYPE , abs )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_ACOS_TYPE , acos )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_ASIN_TYPE , asin )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_ATAN_TYPE , atan )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_CEIL_TYPE , ceil )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_COS_TYPE , cos )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_COSH_TYPE , cosh )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_EXP_TYPE , exp )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_FLOOR_TYPE , floor )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_LOG_TYPE , log )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_LOG10_TYPE , log10 )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_SIN_TYPE , sin )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_SINH_TYPE , sinh )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_SQRT_TYPE , sqrt )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_TAN_TYPE , tan )
DEFINE_ELEMENT_UNARY_OPERATOR ( OPERATOR_TANH_TYPE , tanh )
# undef DEFINE_ELEMENT_UNARY_OPERATOR
//---------------------------------------
2015-01-16 19:39:26 -05:00
2015-01-12 13:20:53 -05:00
///*--- Misc----*/
////---------------------------------------
2015-01-29 01:00:50 -05:00
inline operation_node_type casted ( numeric_type dtype )
2015-01-16 19:39:26 -05:00
{
2015-01-29 01:00:50 -05:00
switch ( dtype )
{
2015-01-29 15:19:40 -05:00
// case BOOL_TYPE: return OPERATOR_CAST_BOOL_TYPE;
2015-01-29 01:00:50 -05:00
case CHAR_TYPE : return OPERATOR_CAST_CHAR_TYPE ;
case UCHAR_TYPE : return OPERATOR_CAST_UCHAR_TYPE ;
case SHORT_TYPE : return OPERATOR_CAST_SHORT_TYPE ;
case USHORT_TYPE : return OPERATOR_CAST_USHORT_TYPE ;
case INT_TYPE : return OPERATOR_CAST_INT_TYPE ;
case UINT_TYPE : return OPERATOR_CAST_UINT_TYPE ;
case LONG_TYPE : return OPERATOR_CAST_LONG_TYPE ;
case ULONG_TYPE : return OPERATOR_CAST_ULONG_TYPE ;
2015-01-29 15:19:40 -05:00
// case FLOAT_TYPE: return OPERATOR_CAST_HALF_TYPE;
2015-01-29 01:00:50 -05:00
case FLOAT_TYPE : return OPERATOR_CAST_FLOAT_TYPE ;
case DOUBLE_TYPE : return OPERATOR_CAST_DOUBLE_TYPE ;
default : throw unknown_datatype ( dtype ) ;
}
2015-01-16 19:39:26 -05:00
}
2015-01-29 01:00:50 -05:00
array_expression cast ( array const & x , numeric_type dtype )
2015-01-31 22:01:48 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , casted ( dtype ) ) , x . context ( ) , dtype , x . shape ( ) ) ; }
2015-01-29 01:00:50 -05:00
array_expression cast ( array_expression const & x , numeric_type dtype )
2015-02-08 00:56:24 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , casted ( dtype ) ) , x . context ( ) , dtype , x . shape ( ) ) ; }
2015-01-29 01:00:50 -05:00
2015-04-29 15:50:57 -04:00
// M x N identity-like expression built with the VDIAG operator: 1 is passed as
// the diagonal value and 0 as the off-diagonal value (presumably — confirm
// against the VDIAG kernel).
isaac::array_expression eye(std::size_t M, std::size_t N, isaac::numeric_type dtype, driver::Context ctx)
{
  op_element const vdiag(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE);
  return array_expression(value_scalar(1), value_scalar(0), vdiag, ctx, dtype, size4(M, N));
}

// M x N expression filled with zero of the requested dtype.
isaac::array_expression zeros(std::size_t M, std::size_t N, isaac::numeric_type dtype, driver::Context ctx)
{
  op_element const fill(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE);
  return array_expression(value_scalar(0, dtype), invalid_node(), fill, ctx, dtype, size4(M, N));
}
2015-01-20 11:17:42 -05:00
2015-01-21 20:08:52 -05:00
inline size4 flip ( size4 const & shape )
2015-04-29 15:50:57 -04:00
{ return size4 ( shape [ 1 ] , shape [ 0 ] ) ; }
2015-01-12 13:20:53 -05:00
inline size4 prod ( size4 const & shape1 , size4 const & shape2 )
2015-04-29 15:50:57 -04:00
{ return size4 ( shape1 [ 0 ] * shape2 [ 0 ] , shape1 [ 1 ] * shape2 [ 1 ] ) ; }
2015-01-12 13:20:53 -05:00
array_expression trans ( array const & x ) \
2015-01-31 22:01:48 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OPERATOR_TRANS_TYPE ) , x . context ( ) , x . dtype ( ) , flip ( x . shape ( ) ) ) ; } \
2015-01-12 13:20:53 -05:00
\
array_expression trans ( array_expression const & x ) \
2015-02-08 00:56:24 -05:00
{ return array_expression ( x , invalid_node ( ) , op_element ( OPERATOR_UNARY_TYPE_FAMILY , OPERATOR_TRANS_TYPE ) , x . context ( ) , x . dtype ( ) , flip ( x . shape ( ) ) ) ; }
2015-01-12 13:20:53 -05:00
namespace detail
{
  // Shared body of both repmat overloads: tiles A rep1 times along the first
  // axis and rep2 times along the second. The result shape is
  // (rep1 * A.shape()[0]) x (rep2 * A.shape()[1]).
  template<class T>
  array_expression repeat_expression(T const & A, int_t const & rep1, int_t const & rep2)
  {
    repeat_infos infos;
    infos.rep1 = rep1;
    infos.rep2 = rep2;
    infos.sub1 = A.shape()[0];
    infos.sub2 = A.shape()[1];
    size4 const tiled(infos.rep1 * infos.sub1, infos.rep2 * infos.sub2);
    return array_expression(A, infos, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_REPEAT_TYPE),
                            A.context(), A.dtype(), tiled);
  }
}

array_expression repmat(array const & A, int_t const & rep1, int_t const & rep2)
{
  return detail::repeat_expression(A, rep1, rep2);
}

array_expression repmat(array_expression const & A, int_t const & rep1, int_t const & rep2)
{
  return detail::repeat_expression(A, rep1, rep2);
}
////---------------------------------------
///*--- Reductions ---*/
////---------------------------------------
2015-07-11 09:36:01 -04:00
namespace detail
{
  /* Builds the reduction of `x` with operator `op` along `axis`:
       axis == -1 : reduce everything to one value         -> shape size4(1)
       axis ==  0 : OPERATOR_COLUMNS_DOT_TYPE_FAMILY node  -> shape size4(shape[1])
       axis ==  1 : OPERATOR_ROWS_DOT_TYPE_FAMILY node     -> shape size4(shape[0])
     Throws std::out_of_range for any other axis, or for an axis larger than
     x.nshape(). */
  template<class T>
  static array_expression reduce_along_axis(T const & x, int_t axis, operation_node_type op)
  {
    // Only -1, 0 and 1 have a reduction family below. The bound against
    // nshape() alone let axis == 2 through whenever nshape() was 2, and that
    // axis then silently fell into the row-reduction branch.
    if(axis < -1 || axis > 1 || axis > x.nshape())
      throw std::out_of_range(" The axis entry is out of bounds ");
    if(axis == -1)
      return array_expression(x, invalid_node(), op_element(OPERATOR_VECTOR_DOT_TYPE_FAMILY, op), x.context(), x.dtype(), size4(1));
    if(axis == 0)
      return array_expression(x, invalid_node(), op_element(OPERATOR_COLUMNS_DOT_TYPE_FAMILY, op), x.context(), x.dtype(), size4(x.shape()[1]));
    return array_expression(x, invalid_node(), op_element(OPERATOR_ROWS_DOT_TYPE_FAMILY, op), x.context(), x.dtype(), size4(x.shape()[0]));
  }
}

/* Defines OPNAME(array, axis) and OPNAME(array_expression, axis) as reductions
   using operator OP; see detail::reduce_along_axis for the axis semantics. */
#define DEFINE_DOT(OP, OPNAME)\
array_expression OPNAME(array const & x, int_t axis)\
{ return detail::reduce_along_axis(x, axis, OP); }\
\
array_expression OPNAME(array_expression const & x, int_t axis)\
{ return detail::reduce_along_axis(x, axis, OP); }

DEFINE_DOT(OPERATOR_ADD_TYPE, sum)
DEFINE_DOT(OPERATOR_ELEMENT_ARGMAX_TYPE, argmax)
DEFINE_DOT(OPERATOR_ELEMENT_MAX_TYPE, max)
DEFINE_DOT(OPERATOR_ELEMENT_MIN_TYPE, min)
DEFINE_DOT(OPERATOR_ELEMENT_ARGMIN_TYPE, argmin)

#undef DEFINE_DOT
2015-01-12 13:20:53 -05:00
namespace detail
{
array_expression matmatprod ( array const & A , array const & B )
{
2015-04-29 15:50:57 -04:00
size4 shape ( A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] ) ;
2015-07-11 09:36:01 -04:00
return array_expression ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , OPERATOR_GEMM_NN_TYPE ) , A . context ( ) , A . dtype ( ) , shape ) ;
2015-01-12 13:20:53 -05:00
}
array_expression matmatprod ( array_expression const & A , array const & B )
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-04-29 15:50:57 -04:00
size4 shape ( A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] ) ;
2015-01-12 13:20:53 -05:00
2015-01-31 22:01:48 -05:00
array_expression : : node & A_root = const_cast < array_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
if ( A_trans ) {
2015-07-11 09:36:01 -04:00
type = OPERATOR_GEMM_TN_TYPE ;
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
array_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
2015-01-31 22:01:48 -05:00
array_expression : : node & res_root = const_cast < array_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( A_trans ) res_root . lhs = A_root . lhs ;
return res ;
}
array_expression matmatprod ( array const & A , array_expression const & B )
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-04-29 15:50:57 -04:00
size4 shape ( A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] ) ;
2015-01-12 13:20:53 -05:00
2015-01-31 22:01:48 -05:00
array_expression : : node & B_root = const_cast < array_expression : : node & > ( B . tree ( ) [ B . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool B_trans = B_root . op . type = = OPERATOR_TRANS_TYPE ;
if ( B_trans ) {
2015-07-11 09:36:01 -04:00
type = OPERATOR_GEMM_NT_TYPE ;
2015-01-12 13:20:53 -05:00
}
2015-04-29 15:50:57 -04:00
2015-07-11 09:36:01 -04:00
array_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
2015-01-31 22:01:48 -05:00
array_expression : : node & res_root = const_cast < array_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( B_trans ) res_root . rhs = B_root . lhs ;
return res ;
}
array_expression matmatprod ( array_expression const & A , array_expression const & B )
{
2015-07-11 09:36:01 -04:00
operation_node_type type = OPERATOR_GEMM_NN_TYPE ;
2015-01-31 22:01:48 -05:00
array_expression : : node & A_root = const_cast < array_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
array_expression : : node & B_root = const_cast < array_expression : : node & > ( B . tree ( ) [ B . root ( ) ] ) ;
2015-04-29 15:50:57 -04:00
size4 shape ( A . shape ( ) [ 0 ] , B . shape ( ) [ 1 ] ) ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
bool B_trans = B_root . op . type = = OPERATOR_TRANS_TYPE ;
2015-06-27 13:53:31 -04:00
2015-07-11 09:36:01 -04:00
if ( A_trans & & B_trans ) type = OPERATOR_GEMM_TT_TYPE ;
else if ( A_trans & & ! B_trans ) type = OPERATOR_GEMM_TN_TYPE ;
else if ( ! A_trans & & B_trans ) type = OPERATOR_GEMM_NT_TYPE ;
else type = OPERATOR_GEMM_NN_TYPE ;
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
array_expression res ( A , B , op_element ( OPERATOR_GEMM_TYPE_FAMILY , type ) , A . context ( ) , A . dtype ( ) , shape ) ;
2015-01-31 22:01:48 -05:00
array_expression : : node & res_root = const_cast < array_expression : : node & > ( res . tree ( ) [ res . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
if ( A_trans ) res_root . lhs = A_root . lhs ;
if ( B_trans ) res_root . rhs = B_root . lhs ;
return res ;
}
template < class T >
array_expression matvecprod ( array const & A , T const & x )
{
2015-04-29 15:50:57 -04:00
int_t M = A . shape ( ) [ 0 ] ;
int_t N = A . shape ( ) [ 1 ] ;
2015-05-01 15:57:03 -04:00
return sum ( A * repmat ( reshape ( x , 1 , N ) , M , 1 ) , 1 ) ;
2015-01-12 13:20:53 -05:00
}
template < class T >
array_expression matvecprod ( array_expression const & A , T const & x )
{
2015-04-29 15:50:57 -04:00
int_t M = A . shape ( ) [ 0 ] ;
int_t N = A . shape ( ) [ 1 ] ;
2015-01-31 22:01:48 -05:00
array_expression : : node & A_root = const_cast < array_expression : : node & > ( A . tree ( ) [ A . root ( ) ] ) ;
2015-01-12 13:20:53 -05:00
bool A_trans = A_root . op . type = = OPERATOR_TRANS_TYPE ;
2015-06-30 17:55:57 -04:00
while ( A_root . lhs . type_family = = COMPOSITE_OPERATOR_FAMILY ) {
A_root = A . tree ( ) [ A_root . lhs . node_index ] ;
A_trans ^ = A_root . op . type = = OPERATOR_TRANS_TYPE ;
}
2015-01-17 15:47:52 -05:00
if ( A_trans )
{
2015-02-08 00:56:24 -05:00
array_expression tmp ( A , repmat ( x , 1 , M ) , op_element ( OPERATOR_BINARY_TYPE_FAMILY , OPERATOR_ELEMENT_PROD_TYPE ) , A . context ( ) , A . dtype ( ) , size4 ( N , M ) ) ;
2015-01-12 13:20:53 -05:00
//Remove trans
2015-01-16 07:31:39 -05:00
tmp . tree ( ) [ tmp . root ( ) ] . lhs = A . tree ( ) [ A . root ( ) ] . lhs ;
2015-05-01 15:57:03 -04:00
return sum ( tmp , 0 ) ;
2015-01-12 13:20:53 -05:00
}
else
2015-05-01 15:57:03 -04:00
return sum ( A * repmat ( reshape ( x , 1 , N ) , M , 1 ) , 1 ) ;
2015-01-12 13:20:53 -05:00
}
}
2015-04-29 15:50:57 -04:00
/* Builds a unary OPERATOR_RESHAPE_TYPE expression over array `x` whose
   declared shape is shape0 x shape1. */
array_expression reshape(array const & x, int_t shape0, int_t shape1)
{
  size4 const new_shape(shape0, shape1);
  return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE),
                          x.context(), x.dtype(), new_shape);
}

/* Same as above, but over an existing expression. */
array_expression reshape(array_expression const & x, int_t shape0, int_t shape1)
{
  size4 const new_shape(shape0, shape1);
  return array_expression(x, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_RESHAPE_TYPE),
                          x.context(), x.dtype(), new_shape);
}
2015-01-18 14:52:45 -05:00
2015-01-12 13:20:53 -05:00
/* dot(x, y) dispatches on the nshape() value of each operand:
     - either operand has nshape() < 1          -> element-wise x*y
     - nshape(x) == 2, nshape(y) == 1           -> detail::matvecprod(x, y)
     - nshape(x) == 1, nshape(y) == 2           -> trans(matvecprod(trans(y), trans(x)))
     - both nshape() == 1: x single-column and y single-row -> outer(x, y);
       x single-row and y single-column -> sum(x*trans(y)); otherwise sum(x*y)
     - anything else                            -> detail::matmatprod(x, y)
   Expanded once for each array / array_expression operand pairing.
   (Only block comments inside the macro: a line comment would swallow the
   backslash continuation.) */
#define DEFINE_DOT(LTYPE, RTYPE) \
array_expression dot(LTYPE const & x, RTYPE const & y) \
{ \
  if(x.nshape() < 1 || y.nshape() < 1) \
    return x * y; \
  if(x.nshape() == 2 && y.nshape() == 1) \
    return detail::matvecprod(x, y); \
  if(x.nshape() == 1 && y.nshape() == 2) \
    return trans(detail::matvecprod(trans(y), trans(x))); \
  if(x.nshape() == 1 && y.nshape() == 1) \
  { \
    if(x.shape()[1] == 1 && y.shape()[0] == 1) \
      return outer(x, y); \
    if(x.shape()[0] == 1 && y.shape()[1] == 1) \
      return sum(x * trans(y)); \
    return sum(x * y); \
  } \
  return detail::matmatprod(x, y); \
}
DEFINE_DOT(array, array)
DEFINE_DOT(array_expression, array)
DEFINE_DOT(array, array_expression)
DEFINE_DOT(array_expression, array_expression)
#undef DEFINE_DOT
2015-01-17 10:48:02 -05:00
2015-01-16 07:31:39 -05:00
/* norm(x, order): order 1 gives sum(abs(x)); any other order gives
   sqrt(sum(pow(x, 2))). Debug builds assert that order is 1 or 2. */
#define DEFINE_NORM(TYPE)\
array_expression norm(TYPE const & x, unsigned int order)\
{\
  assert(order > 0 && order < 3);\
  if(order == 1)\
    return sum(abs(x));\
  return sqrt(sum(pow(x, 2)));\
}
DEFINE_NORM(array)
DEFINE_NORM(array_expression)

#undef DEFINE_NORM
2015-01-12 13:20:53 -05:00
/*--- Copy ----*/
//---------------------------------------
//void*
2015-04-29 15:50:57 -04:00
void copy ( void const * data , array & x , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
unsigned int dtypesize = size_of ( x . dtype ( ) ) ;
2015-04-29 15:50:57 -04:00
if ( x . ld ( ) = = x . shape ( ) [ 0 ] )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
queue . write ( x . data ( ) , CL_FALSE , 0 , x . dsize ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
else
{
2015-04-29 15:50:57 -04:00
array tmp ( x . shape ( ) [ 0 ] , x . shape ( ) [ 1 ] , x . dtype ( ) , x . context ( ) ) ;
queue . write ( x . data ( ) , CL_FALSE , 0 , tmp . dsize ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
x = tmp ;
}
if ( blocking )
2015-04-29 15:50:57 -04:00
driver : : synchronize ( x . context ( ) ) ;
2015-01-12 13:20:53 -05:00
}
2015-04-29 15:50:57 -04:00
void copy ( array const & x , void * data , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
unsigned int dtypesize = size_of ( x . dtype ( ) ) ;
2015-04-29 15:50:57 -04:00
if ( x . ld ( ) = = x . shape ( ) [ 0 ] )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
queue . read ( x . data ( ) , CL_FALSE , 0 , x . dsize ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
else
{
2015-04-29 15:50:57 -04:00
array tmp ( x . shape ( ) [ 0 ] , x . shape ( ) [ 1 ] , x . dtype ( ) , x . context ( ) ) ;
2015-01-12 13:20:53 -05:00
tmp = x ;
2015-04-29 15:50:57 -04:00
queue . read ( tmp . data ( ) , CL_FALSE , 0 , tmp . dsize ( ) * dtypesize , data ) ;
2015-01-12 13:20:53 -05:00
}
if ( blocking )
2015-04-29 15:50:57 -04:00
driver : : synchronize ( x . context ( ) ) ;
2015-01-12 13:20:53 -05:00
}
void copy ( void const * data , array & x , bool blocking )
2015-04-29 15:50:57 -04:00
{ copy ( data , x , driver : : queues [ x . context ( ) ] [ 0 ] , blocking ) ; }
2015-01-12 13:20:53 -05:00
void copy ( array const & x , void * data , bool blocking )
2015-04-29 15:50:57 -04:00
{ copy ( x , data , driver : : queues [ x . context ( ) ] [ 0 ] , blocking ) ; }
2015-01-12 13:20:53 -05:00
//std::vector<>
template < class T >
2015-04-29 15:50:57 -04:00
void copy ( std : : vector < T > const & cx , array & x , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
if ( x . ld ( ) = = x . shape ( ) [ 0 ] )
2015-07-20 23:07:53 -07:00
assert ( ( int_t ) cx . size ( ) = = x . dsize ( ) ) ;
2015-01-12 13:20:53 -05:00
else
2015-07-20 23:07:53 -07:00
assert ( ( int_t ) cx . size ( ) = = prod ( x . shape ( ) ) ) ;
2015-01-12 13:20:53 -05:00
copy ( ( void const * ) cx . data ( ) , x , queue , blocking ) ;
}
template < class T >
2015-04-29 15:50:57 -04:00
void copy ( array const & x , std : : vector < T > & cx , driver : : CommandQueue & queue , bool blocking )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
if ( x . ld ( ) = = x . shape ( ) [ 0 ] )
2015-07-20 23:07:53 -07:00
assert ( ( int_t ) cx . size ( ) = = x . dsize ( ) ) ;
2015-01-12 13:20:53 -05:00
else
2015-07-20 23:07:53 -07:00
assert ( ( int_t ) cx . size ( ) = = prod ( x . shape ( ) ) ) ;
2015-01-12 13:20:53 -05:00
copy ( x , ( void * ) cx . data ( ) , queue , blocking ) ;
}
template < class T >
void copy ( std : : vector < T > const & cx , array & x , bool blocking )
2015-04-29 15:50:57 -04:00
{ copy ( cx , x , driver : : queues [ x . context ( ) ] [ 0 ] , blocking ) ; }
2015-01-12 13:20:53 -05:00
template < class T >
void copy ( array const & x , std : : vector < T > & cx , bool blocking )
2015-04-29 15:50:57 -04:00
{ copy ( x , cx , driver : : queues [ x . context ( ) ] [ 0 ] , blocking ) ; }
2015-01-12 13:20:53 -05:00
/* Explicit, exported instantiations of the four std::vector <-> array copy
   templates for each host element type listed below. */
#define INSTANTIATE(ELEM) \
template void ISAACAPI copy<ELEM>(std::vector<ELEM> const &, array &, driver::CommandQueue &, bool); \
template void ISAACAPI copy<ELEM>(array const &, std::vector<ELEM> &, driver::CommandQueue &, bool); \
template void ISAACAPI copy<ELEM>(std::vector<ELEM> const &, array &, bool); \
template void ISAACAPI copy<ELEM>(array const &, std::vector<ELEM> &, bool)

INSTANTIATE(char);
INSTANTIATE(unsigned char);
INSTANTIATE(short);
INSTANTIATE(unsigned short);
INSTANTIATE(int);
INSTANTIATE(unsigned int);
INSTANTIATE(long);
INSTANTIATE(unsigned long);
INSTANTIATE(long long);
INSTANTIATE(unsigned long long);
INSTANTIATE(float);
INSTANTIATE(double);

#undef INSTANTIATE
/*--- Stream operators----*/
//---------------------------------------
namespace detail
{
  /* Streams the elements of [begin, end) taken every `stride` positions.
     At most WINDOW leading and WINDOW trailing elements are printed; when more
     than WINDOW elements lie between them, " , ... " marks the gap. Unless
     `col` is set, the output is enclosed in " [ " / " ] ". */
  template<typename ItType>
  static std::ostream& prettyprint(std::ostream& os, ItType begin, ItType const & end, size_t stride = 1, bool col = false, size_t WINDOW = 10)
  {
    if(!col)
      os << " [ ";
    size_t N = (end - begin)/stride;
    size_t upper = std::min(WINDOW, N);
    for(size_t j = 0; j < upper; j++)
    {
      os << *begin;
      if(j < upper - 1)
        os << " , ";
      begin += stride;
    }
    if(upper < N)
    {
      if(N - upper > WINDOW)
        os << " , ... ";
      size_t tail = std::max(N - WINDOW, upper);
      // `begin` now sits on element `upper`: skip the hidden middle so the
      // tail really shows the last elements, not the ones after the head.
      begin += (tail - upper)*stride;
      for(size_t j = tail; j < N; j++)
      {
        os << " , " << *begin;
        begin += stride;
      }
    }
    if(!col)
      os << " ] ";
    return os;
  }
}
std : : ostream & operator < < ( std : : ostream & os , array const & a )
{
size_t WINDOW = 10 ;
2015-01-20 21:02:24 -05:00
numeric_type dtype = a . dtype ( ) ;
2015-04-29 15:50:57 -04:00
size_t M = a . shape ( ) [ 0 ] ;
size_t N = a . shape ( ) [ 1 ] ;
2015-01-12 13:20:53 -05:00
2015-01-13 14:44:19 -05:00
if ( M > 1 & & N = = 1 )
std : : swap ( M , N ) ;
2015-01-20 21:02:24 -05:00
void * tmp = new char [ M * N * size_of ( dtype ) ] ;
copy ( a , ( void * ) tmp ) ;
2015-01-12 13:20:53 -05:00
os < < " [ " ;
size_t upper = std : : min ( WINDOW , M ) ;
2015-01-20 21:02:24 -05:00
# define HANDLE(ADTYPE, CTYPE) case ADTYPE: detail::prettyprint(os, reinterpret_cast<CTYPE*>(tmp) + i, reinterpret_cast<CTYPE*>(tmp) + M*N + i, M, true, WINDOW); break;
2015-01-12 13:20:53 -05:00
for ( unsigned int i = 0 ; i < upper ; + + i )
{
if ( i > 0 )
os < < " " ;
2015-01-20 21:02:24 -05:00
switch ( dtype )
{
2015-01-29 15:19:40 -05:00
// HANDLE(BOOL_TYPE, cl_bool)
2015-05-04 21:23:05 -04:00
HANDLE ( CHAR_TYPE , char )
HANDLE ( UCHAR_TYPE , unsigned char )
HANDLE ( SHORT_TYPE , short )
HANDLE ( USHORT_TYPE , unsigned short )
HANDLE ( INT_TYPE , int )
HANDLE ( UINT_TYPE , unsigned int )
HANDLE ( LONG_TYPE , long )
HANDLE ( ULONG_TYPE , unsigned long )
2015-01-29 15:19:40 -05:00
// HANDLE(HALF_TYPE, cl_half)
2015-05-04 21:23:05 -04:00
HANDLE ( FLOAT_TYPE , float )
HANDLE ( DOUBLE_TYPE , double )
2015-01-20 21:02:24 -05:00
default : throw unknown_datatype ( dtype ) ;
}
2015-01-12 13:20:53 -05:00
if ( i < upper - 1 )
os < < std : : endl ;
}
if ( upper < M )
{
if ( N - upper > WINDOW )
os < < std : : endl < < " ... " ;
for ( size_t i = std : : max ( N - WINDOW , upper ) ; i < N ; i + + )
{
os < < std : : endl < < " " ;
2015-01-20 21:02:24 -05:00
switch ( dtype )
{
2015-01-29 15:19:40 -05:00
// HANDLE(BOOL_TYPE, cl_bool)
2015-05-04 21:23:05 -04:00
HANDLE ( CHAR_TYPE , char )
HANDLE ( UCHAR_TYPE , unsigned char )
HANDLE ( SHORT_TYPE , short )
HANDLE ( USHORT_TYPE , unsigned short )
HANDLE ( INT_TYPE , int )
HANDLE ( UINT_TYPE , unsigned int )
HANDLE ( LONG_TYPE , long )
HANDLE ( ULONG_TYPE , unsigned long )
2015-01-29 15:19:40 -05:00
// HANDLE(HALF_TYPE, cl_half)
2015-05-04 21:23:05 -04:00
HANDLE ( FLOAT_TYPE , float )
HANDLE ( DOUBLE_TYPE , double )
2015-01-20 21:02:24 -05:00
default : throw unknown_datatype ( dtype ) ;
}
2015-01-12 13:20:53 -05:00
}
}
os < < " ] " ;
return os ;
}
}