Better value_scalar constructor

This commit is contained in:
Philippe Tillet
2015-01-20 21:02:24 -05:00
parent e74563070a
commit d285bd81e0
3 changed files with 63 additions and 29 deletions

View File

@@ -27,20 +27,21 @@ union values_holder
class value_scalar class value_scalar
{ {
void init(scalar const &); template<class T> void init(T const &);
template<class T> T cast() const; template<class T> T cast() const;
public: public:
value_scalar(cl_char value); #define ATIDLAS_INSTANTIATE(CLTYPE, ADTYPE) value_scalar(CLTYPE value, numeric_type dtype = ADTYPE);
value_scalar(cl_uchar value); ATIDLAS_INSTANTIATE(cl_char, CHAR_TYPE)
value_scalar(cl_short value); ATIDLAS_INSTANTIATE(cl_uchar, UCHAR_TYPE)
value_scalar(cl_ushort value); ATIDLAS_INSTANTIATE(cl_short, SHORT_TYPE)
value_scalar(cl_int value); ATIDLAS_INSTANTIATE(cl_ushort, USHORT_TYPE)
value_scalar(cl_uint value); ATIDLAS_INSTANTIATE(cl_int, INT_TYPE)
value_scalar(cl_long value); ATIDLAS_INSTANTIATE(cl_uint, UINT_TYPE)
value_scalar(cl_ulong value); ATIDLAS_INSTANTIATE(cl_long, LONG_TYPE)
// value_scalar(cl_half value); ATIDLAS_INSTANTIATE(cl_ulong, ULONG_TYPE)
value_scalar(cl_float value); ATIDLAS_INSTANTIATE(cl_float, FLOAT_TYPE)
value_scalar(cl_double value); ATIDLAS_INSTANTIATE(cl_double, DOUBLE_TYPE)
#undef ATIDLAS_INSTANTIATE
explicit value_scalar(scalar const &); explicit value_scalar(scalar const &);
explicit value_scalar(array_expression const &); explicit value_scalar(array_expression const &);
explicit value_scalar(numeric_type dtype); explicit value_scalar(numeric_type dtype);

View File

@@ -836,23 +836,38 @@ namespace detail
std::ostream& operator<<(std::ostream & os, array const & a) std::ostream& operator<<(std::ostream & os, array const & a)
{ {
size_t WINDOW = 10; size_t WINDOW = 10;
numeric_type dtype = a.dtype();
size_t M = a.shape()._1; size_t M = a.shape()._1;
size_t N = a.shape()._2; size_t N = a.shape()._2;
if(M>1 && N==1) if(M>1 && N==1)
std::swap(M, N); std::swap(M, N);
std::vector<float> tmp(M*N); void* tmp = new char[M*N*size_of(dtype)];
copy(a, tmp); copy(a, (void*)tmp);
os << "[ " ; os << "[ " ;
size_t upper = std::min(WINDOW,M); size_t upper = std::min(WINDOW,M);
#define HANDLE(ADTYPE, CTYPE) case ADTYPE: detail::prettyprint(os, reinterpret_cast<CTYPE*>(tmp) + i, reinterpret_cast<CTYPE*>(tmp) + M*N + i, M, true, WINDOW); break;
for(unsigned int i = 0 ; i < upper ; ++i) for(unsigned int i = 0 ; i < upper ; ++i)
{ {
if(i>0) if(i>0)
os << " "; os << " ";
detail::prettyprint(os, tmp.begin() + i, tmp.end() + i, M, true, WINDOW); switch(dtype)
{
HANDLE(CHAR_TYPE, cl_char)
HANDLE(UCHAR_TYPE, cl_uchar)
HANDLE(SHORT_TYPE, cl_short)
HANDLE(USHORT_TYPE, cl_ushort)
HANDLE(INT_TYPE, cl_int)
HANDLE(UINT_TYPE, cl_uint)
HANDLE(LONG_TYPE, cl_long)
HANDLE(ULONG_TYPE, cl_ulong)
HANDLE(FLOAT_TYPE, cl_float)
HANDLE(DOUBLE_TYPE, cl_double)
default: throw unknown_datatype(dtype);
}
if(i < upper-1) if(i < upper-1)
os << std::endl; os << std::endl;
} }
@@ -863,7 +878,20 @@ std::ostream& operator<<(std::ostream & os, array const & a)
for(size_t i = std::max(N - WINDOW, upper) ; i < N ; i++) for(size_t i = std::max(N - WINDOW, upper) ; i < N ; i++)
{ {
os << std::endl << " "; os << std::endl << " ";
detail::prettyprint(os, tmp.begin() + i, tmp.end() + i, M, true, WINDOW); switch(dtype)
{
HANDLE(CHAR_TYPE, cl_char)
HANDLE(UCHAR_TYPE, cl_uchar)
HANDLE(SHORT_TYPE, cl_short)
HANDLE(USHORT_TYPE, cl_ushort)
HANDLE(INT_TYPE, cl_int)
HANDLE(UINT_TYPE, cl_uint)
HANDLE(LONG_TYPE, cl_long)
HANDLE(ULONG_TYPE, cl_ulong)
HANDLE(FLOAT_TYPE, cl_float)
HANDLE(DOUBLE_TYPE, cl_double)
default: throw unknown_datatype(dtype);
}
} }
} }
os << " ]"; os << " ]";

View File

@@ -7,7 +7,8 @@
namespace atidlas namespace atidlas
{ {
void value_scalar::init(scalar const & s) template<class T>
void value_scalar::init(T const & s)
{ {
switch(dtype_) switch(dtype_)
{ {
@@ -25,17 +26,21 @@ void value_scalar::init(scalar const & s)
} }
} }
value_scalar::value_scalar(cl_char value) : dtype_(CHAR_TYPE) { values_.int8 = value; } #define INSTANTIATE(CLTYPE, ADTYPE) value_scalar::value_scalar(CLTYPE value, numeric_type dtype) : dtype_(dtype) { init(value); }
value_scalar::value_scalar(cl_uchar value) : dtype_(UCHAR_TYPE) { values_.uint8 = value; }
value_scalar::value_scalar(cl_short value) : dtype_(SHORT_TYPE) { values_.int16 = value; } INSTANTIATE(cl_char, CHAR_TYPE)
value_scalar::value_scalar(cl_ushort value) : dtype_(USHORT_TYPE) { values_.uint16 = value; } INSTANTIATE(cl_uchar, UCHAR_TYPE)
value_scalar::value_scalar(cl_int value) : dtype_(INT_TYPE) { values_.int32 = value; } INSTANTIATE(cl_short, SHORT_TYPE)
value_scalar::value_scalar(cl_uint value) : dtype_(UINT_TYPE) { values_.uint32 = value; } INSTANTIATE(cl_ushort, USHORT_TYPE)
value_scalar::value_scalar(cl_long value) : dtype_(LONG_TYPE) { values_.int64 = value; } INSTANTIATE(cl_int, INT_TYPE)
value_scalar::value_scalar(cl_ulong value) : dtype_(ULONG_TYPE) { values_.uint64 = value; } INSTANTIATE(cl_uint, UINT_TYPE)
//value_scalar::value_scalar(cl_half value) : dtype_(HALF_TYPE) { values_.float16 = value; } INSTANTIATE(cl_long, LONG_TYPE)
value_scalar::value_scalar(cl_float value) : dtype_(FLOAT_TYPE) { values_.float32 = value; } INSTANTIATE(cl_ulong, ULONG_TYPE)
value_scalar::value_scalar(cl_double value) : dtype_(DOUBLE_TYPE) { values_.float64 = value; } INSTANTIATE(cl_float, FLOAT_TYPE)
INSTANTIATE(cl_double, DOUBLE_TYPE)
#undef INSTANTIATE
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {} value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { init(s); } value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { init(s); }
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { init(scalar(expr)); } value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { init(scalar(expr)); }