2015-01-12 13:20:53 -05:00
|
|
|
#ifndef ATIDLAS_TYPES_H
|
|
|
|
#define ATIDLAS_TYPES_H
|
|
|
|
|
2015-01-27 16:14:02 -05:00
|
|
|
#include <algorithm>
#include <list>
#include <string>

#include <CL/cl.hpp>

#include "atidlas/exception/unknown_datatype.h"
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
namespace atidlas
|
|
|
|
{
|
|
|
|
|
2015-01-18 14:52:45 -05:00
|
|
|
// Integer type used throughout this header for sizes, offsets and strides.
typedef int int_t;
|
|
|
|
|
|
|
|
struct size4
|
|
|
|
{
|
|
|
|
size4(int_t s1, int_t s2 = 1) : _1(s1), _2(s2){ }
|
|
|
|
int_t prod() const { return _1*_2; }
|
|
|
|
bool operator==(size4 const & other) const { return _1==other._1 && _2==other._2; }
|
|
|
|
int_t _1;
|
|
|
|
int_t _2;
|
|
|
|
};
|
|
|
|
// Free-function form of size4::prod(): product of the two components of s.
inline int_t prod(size4 const & s)
{
  return s.prod();
}
|
|
|
|
// Larger of the two components of s.
inline int_t max(size4 const & s)
{
  int_t const & first = s._1;
  int_t const & second = s._2;
  return std::max(first, second);
}
|
|
|
|
// Smaller of the two components of s.
inline int_t min(size4 const & s)
{
  int_t const & first = s._1;
  int_t const & second = s._2;
  return std::min(first, second);
}
|
|
|
|
|
|
|
|
struct repeat_infos
|
|
|
|
{
|
|
|
|
int_t sub1;
|
|
|
|
int_t sub2;
|
|
|
|
int_t rep1;
|
|
|
|
int_t rep2;
|
|
|
|
};
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
|
|
|
|
// Tags identifying the scalar element types known to the library.
// INVALID_NUMERIC_TYPE is pinned to 0; every other enumerator takes the next
// value in declaration order. The bool and half entries are disabled.
enum numeric_type
{
  INVALID_NUMERIC_TYPE = 0,
  // BOOL_TYPE,
  CHAR_TYPE,
  UCHAR_TYPE,
  SHORT_TYPE,
  USHORT_TYPE,
  INT_TYPE,
  UINT_TYPE,
  LONG_TYPE,
  ULONG_TYPE,
  // HALF_TYPE,
  FLOAT_TYPE,
  DOUBLE_TYPE
};
|
|
|
|
|
2015-02-01 17:15:41 -05:00
|
|
|
struct array_infos
|
|
|
|
{
|
|
|
|
numeric_type dtype;
|
|
|
|
cl_mem data;
|
|
|
|
int_t shape1;
|
|
|
|
int_t shape2;
|
|
|
|
int_t start1;
|
|
|
|
int_t start2;
|
|
|
|
int_t stride1;
|
|
|
|
int_t stride2;
|
|
|
|
int_t ld;
|
|
|
|
};
|
|
|
|
|
2015-02-01 23:56:05 -05:00
|
|
|
class operation_cache
|
|
|
|
{
|
|
|
|
struct infos
|
|
|
|
{
|
|
|
|
infos(cl::CommandQueue & q, cl::Kernel const & k, cl::NDRange const & off, cl::NDRange const & g, cl::NDRange const & l)
|
|
|
|
: queue(q), kernel(k), offset(off), grange(g), lrange(l) {}
|
|
|
|
|
|
|
|
cl::CommandQueue & queue;
|
|
|
|
cl::Kernel kernel;
|
|
|
|
cl::NDRange offset;
|
|
|
|
cl::NDRange grange;
|
|
|
|
cl::NDRange lrange;
|
|
|
|
};
|
|
|
|
|
|
|
|
public:
|
|
|
|
void push_back(cl::CommandQueue & queue, cl::Kernel const & kernel, cl::NDRange const & offset, cl::NDRange const & grange, cl::NDRange const & lrange)
|
|
|
|
{
|
|
|
|
l_.push_back(infos(queue, kernel, offset, grange, lrange));
|
|
|
|
}
|
|
|
|
|
|
|
|
void enqueue()
|
|
|
|
{
|
|
|
|
for(std::list<infos>::iterator it = l_.begin() ; it != l_.end() ; ++it)
|
|
|
|
it->queue.enqueueNDRangeKernel(it->kernel, it->offset, it->grange, it->lrange);
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
std::list<infos> l_;
|
|
|
|
};
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
inline std::string numeric_type_to_string(numeric_type const & type)
|
|
|
|
{
|
|
|
|
switch (type)
|
|
|
|
{
|
2015-01-29 15:19:40 -05:00
|
|
|
// case BOOL_TYPE: return "bool";
|
2015-01-12 13:20:53 -05:00
|
|
|
case CHAR_TYPE: return "char";
|
|
|
|
case UCHAR_TYPE: return "uchar";
|
|
|
|
case SHORT_TYPE: return "short";
|
|
|
|
case USHORT_TYPE: return "ushort";
|
|
|
|
case INT_TYPE: return "int";
|
|
|
|
case UINT_TYPE: return "uint";
|
|
|
|
case LONG_TYPE: return "long";
|
|
|
|
case ULONG_TYPE: return "ulong";
|
2015-01-29 15:19:40 -05:00
|
|
|
// case HALF_TYPE : return "half";
|
2015-01-12 13:20:53 -05:00
|
|
|
case FLOAT_TYPE : return "float";
|
|
|
|
case DOUBLE_TYPE : return "double";
|
2015-01-19 21:29:47 -05:00
|
|
|
default : throw unknown_datatype(type);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
template<class T> struct to_numeric_type;
|
2015-01-29 15:19:40 -05:00
|
|
|
//template<> struct to_numeric_type<cl_bool> { static const numeric_type value = BOOL_TYPE; };
|
2015-01-12 13:20:53 -05:00
|
|
|
template<> struct to_numeric_type<cl_char> { static const numeric_type value = CHAR_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_uchar> { static const numeric_type value = UCHAR_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_short> { static const numeric_type value = SHORT_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_ushort> { static const numeric_type value = USHORT_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_int> { static const numeric_type value = INT_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_uint> { static const numeric_type value = UINT_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_long> { static const numeric_type value = LONG_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_ulong> { static const numeric_type value = ULONG_TYPE; };
|
2015-01-29 15:19:40 -05:00
|
|
|
//template<> struct to_numeric_type<cl_float> { static const numeric_type value = HALF_TYPE; };
|
2015-01-12 13:20:53 -05:00
|
|
|
template<> struct to_numeric_type<cl_float> { static const numeric_type value = FLOAT_TYPE; };
|
|
|
|
template<> struct to_numeric_type<cl_double> { static const numeric_type value = DOUBLE_TYPE; };
|
|
|
|
|
|
|
|
// Returns the size associated with a numeric_type tag: 1 for the char types,
// 2 for the short types, 4 for int/uint/float and 8 for long/ulong/double.
// Throws unknown_datatype for any other value, which includes
// INVALID_NUMERIC_TYPE.
inline unsigned int size_of(numeric_type type)
{
  switch (type)
  {
    // case BOOL_TYPE:
    case UCHAR_TYPE:
    case CHAR_TYPE: return 1;

    // case HALF_TYPE:
    case USHORT_TYPE:
    case SHORT_TYPE: return 2;

    case UINT_TYPE:
    case INT_TYPE:
    case FLOAT_TYPE: return 4;

    case ULONG_TYPE:
    case LONG_TYPE:
    case DOUBLE_TYPE: return 8;

    default: throw unknown_datatype(type);
  }
}
|
|
|
|
|
|
|
|
// Tags for the kinds of expression distinguished by the library.
// Values start at 0 and follow declaration order; INVALID_EXPRESSION_TYPE is
// the last enumerator.
enum expression_type
{
  SCALAR_AXPY_TYPE,
  VECTOR_AXPY_TYPE,
  MATRIX_AXPY_TYPE,
  REDUCTION_TYPE,
  ROW_WISE_REDUCTION_TYPE,
  COL_WISE_REDUCTION_TYPE,
  MATRIX_PRODUCT_NN_TYPE,
  MATRIX_PRODUCT_TN_TYPE,
  MATRIX_PRODUCT_NT_TYPE,
  MATRIX_PRODUCT_TT_TYPE,
  INVALID_EXPRESSION_TYPE
};
|
|
|
|
|
|
|
|
struct slice
|
|
|
|
{
|
|
|
|
slice(int_t _start, int_t _end, int_t _stride = 1) : start(_start), size((_end - _start)/_stride), stride(_stride) { }
|
|
|
|
int_t start;
|
|
|
|
int_t size;
|
|
|
|
int_t stride;
|
|
|
|
};
|
|
|
|
// Short alias for slice, so that a range can be written as _(start, end, stride).
typedef slice _;
|
|
|
|
|
|
|
|
// Empty class with no members.
class obj_base
{
};
|
|
|
|
|
|
|
|
}
|
|
|
|
#endif
|