low level representation of array
This commit is contained in:
@@ -42,19 +42,19 @@ tools::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, v
|
||||
}
|
||||
|
||||
/** @brief Vector mapping */
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(array const & a) const
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a) const
|
||||
{
|
||||
std::string dtype = numeric_type_to_string(a.dtype());
|
||||
unsigned int id = binder_.get(&a.data());
|
||||
if(max(a.shape())==1)
|
||||
std::string dtype = numeric_type_to_string(a.dtype);
|
||||
unsigned int id = binder_.get(a.data);
|
||||
if(a.shape1==1 && a.shape2==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
|
||||
else
|
||||
{
|
||||
//Column vector
|
||||
if(a.shape()._1>1 && a.shape()._2==1)
|
||||
if(a.shape1>1 && a.shape2==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
|
||||
//Row vector
|
||||
else if(a.shape()._1==1 && a.shape()._2>1)
|
||||
else if(a.shape1==1 && a.shape2>1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
|
||||
//Matrix
|
||||
else
|
||||
@@ -72,9 +72,9 @@ tools::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const
|
||||
{
|
||||
switch(lhs_rhs.type_family)
|
||||
{
|
||||
case INFOS_TYPE_FAMILY: return create(*lhs_rhs.tuple);
|
||||
case INFOS_TYPE_FAMILY: return create(lhs_rhs.tuple);
|
||||
case VALUE_TYPE_FAMILY: return create(lhs_rhs.dtype, lhs_rhs.vscalar);
|
||||
case ARRAY_TYPE_FAMILY: return create(*lhs_rhs.array);
|
||||
case ARRAY_TYPE_FAMILY: return create(lhs_rhs.array);
|
||||
default: throw "";
|
||||
}
|
||||
}
|
||||
@@ -136,32 +136,32 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
|
||||
}
|
||||
|
||||
/** @brief Vector mapping */
|
||||
void base::set_arguments_functor::set_arguments(array const & x) const
|
||||
void base::set_arguments_functor::set_arguments(array_infos const & x) const
|
||||
{
|
||||
bool is_bound = binder_.bind(&x.data());
|
||||
bool is_bound = binder_.bind(x.data);
|
||||
if (is_bound)
|
||||
{
|
||||
//scalar
|
||||
if(x.nshape()==0)
|
||||
if(x.shape1==1 && x.shape2==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data());
|
||||
kernel_.setArg(current_arg_++, x.data);
|
||||
}
|
||||
//array
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data());
|
||||
if(x.nshape()==1)
|
||||
kernel_.setArg(current_arg_++, x.data);
|
||||
if(x.shape1==1 || x.shape2==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(max(x.start())));
|
||||
kernel_.setArg(current_arg_++, cl_uint(max(x.stride())));
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,10 +169,10 @@ void base::set_arguments_functor::set_arguments(array const & x) const
|
||||
|
||||
void base::set_arguments_functor::set_arguments(repeat_infos const & i) const
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.sub._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.sub._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.rep._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.rep._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.sub1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.sub2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.rep1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(i.rep2));
|
||||
}
|
||||
|
||||
void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs) const
|
||||
@@ -180,8 +180,8 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
|
||||
switch(lhs_rhs.type_family)
|
||||
{
|
||||
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
|
||||
case ARRAY_TYPE_FAMILY: return set_arguments(*lhs_rhs.array);
|
||||
case INFOS_TYPE_FAMILY: return set_arguments(*lhs_rhs.tuple);
|
||||
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
|
||||
case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
|
||||
default: throw "oh noez";
|
||||
}
|
||||
}
|
||||
@@ -376,7 +376,7 @@ bool base::has_strided_access(symbolic_expressions_container const & symbolic_ex
|
||||
{
|
||||
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
|
||||
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
|
||||
if(max(itt->array->stride())>1)
|
||||
if(std::max(itt->array.stride1, itt->array.stride2)>1)
|
||||
return true;
|
||||
if(filter_nodes(&is_strided, **it, true).empty()==false)
|
||||
return true;
|
||||
@@ -388,13 +388,13 @@ int_t base::vector_size(symbolic_expression_node const & node)
|
||||
{
|
||||
using namespace tools;
|
||||
if (node.op.type==OPERATOR_MATRIX_DIAG_TYPE)
|
||||
return std::min<int_t>(node.lhs.array->shape()._1, node.lhs.array->shape()._2);
|
||||
return std::min<int_t>(node.lhs.array.shape1, node.lhs.array.shape2);
|
||||
else if (node.op.type==OPERATOR_MATRIX_ROW_TYPE)
|
||||
return node.lhs.array->shape()._2;
|
||||
return node.lhs.array.shape2;
|
||||
else if (node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
||||
return node.lhs.array->shape()._1;
|
||||
return node.lhs.array.shape1;
|
||||
else
|
||||
return max(node.lhs.array->shape());
|
||||
return std::max(node.lhs.array.shape1, node.lhs.array.shape2);
|
||||
|
||||
}
|
||||
|
||||
@@ -402,13 +402,13 @@ std::pair<int_t, int_t> base::matrix_size(symbolic_expression_node const & node)
|
||||
{
|
||||
if (node.op.type==OPERATOR_VDIAG_TYPE)
|
||||
{
|
||||
int_t size = node.lhs.array->shape()._1;
|
||||
int_t size = node.lhs.array.shape1;
|
||||
return std::make_pair(size,size);
|
||||
}
|
||||
else if(node.op.type==OPERATOR_REPEAT_TYPE)
|
||||
return std::make_pair(node.lhs.array->shape()._1*node.rhs.tuple->rep._1, node.lhs.array->shape()._2*node.rhs.tuple->rep._2);
|
||||
return std::make_pair(node.lhs.array.shape1*node.rhs.tuple.rep1, node.lhs.array.shape2*node.rhs.tuple.rep2);
|
||||
else
|
||||
return std::make_pair(node.lhs.array->shape()._1,node.lhs.array->shape()._2);
|
||||
return std::make_pair(node.lhs.array.shape1,node.lhs.array.shape2);
|
||||
}
|
||||
|
||||
void base::element_wise_loop_1D(kernel_generation_stream & stream, loop_body_base const & loop_body,
|
||||
@@ -482,7 +482,7 @@ unsigned int base::align(unsigned int to_round, unsigned int base)
|
||||
return (to_round + base - 1)/base * base;
|
||||
}
|
||||
|
||||
inline tools::shared_ptr<symbolic_binder> base::make_binder()
|
||||
tools::shared_ptr<symbolic_binder> base::make_binder()
|
||||
{
|
||||
if (binding_policy_==BIND_TO_HANDLE)
|
||||
return tools::shared_ptr<symbolic_binder>(new bind_to_handle());
|
||||
@@ -531,7 +531,7 @@ bool base_impl<TType, PType>::has_misaligned_offset(symbolic_expressions_contain
|
||||
{
|
||||
std::vector<lhs_rhs_element> arrays = filter_elements(DENSE_ARRAY_TYPE, **it);
|
||||
for (std::vector<lhs_rhs_element>::iterator itt = arrays.begin(); itt != arrays.end(); ++itt)
|
||||
if (max(itt->array->start())>0)
|
||||
if (itt->array.start1>0 || itt->array.start2>0)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
Reference in New Issue
Block a user