API: more consistent zeros() initializer

This commit is contained in:
Philippe Tillet
2015-12-21 03:23:38 -05:00
parent da43f89ea4
commit 0d09b0518f
5 changed files with 33 additions and 36 deletions

View File

@@ -323,7 +323,7 @@ static const for_idx_t _i9{9};
//Initializers
ISAACAPI expression_tree eye(int_t, int_t, isaac::numeric_type, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI expression_tree zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI expression_tree zeros(shape_t const & shape, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
//Swap
ISAACAPI void swap(view x, view y);

View File

@@ -708,8 +708,8 @@ array diag(array_base & x, int offset)
}
isaac::expression_tree zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return expression_tree(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, {M, N}); }
isaac::expression_tree zeros(shape_t const & shape, isaac::numeric_type dtype, driver::Context const & ctx)
{ return expression_tree(value_scalar(0, dtype), invalid_node(), op_element(UNARY_TYPE_FAMILY, ADD_TYPE), ctx, dtype, shape); }
inline shape_t flip(shape_t const & shape)
{
@@ -926,7 +926,7 @@ expression_tree dot(LTYPE const & x, RTYPE const & y)\
if(x.shape().max()==1 || y.shape().max()==1)\
return x*y;\
if(x.dim()==2 && x.shape()[1]==0)\
return zeros(x.shape()[0], y.shape()[1], dtype, context);\
return zeros({x.shape()[0], y.shape()[1]}, dtype, context);\
if(x.shape()[0]==0 || (y.dim()==2 && y.shape()[1]==0))\
return expression_tree(invalid_node(), invalid_node(), op_element(UNARY_TYPE_FAMILY, INVALID_TYPE), context, dtype, {0});\
if(x.dim()==1 && y.dim()==1)\

View File

@@ -46,12 +46,12 @@ namespace tools
}
template<class T>
std::vector<T> to_vector(bp::list const & list)
std::vector<T> to_vector(bp::object const & iterable)
{
std::size_t len = bp::len(list);
std::size_t len = bp::len(iterable);
std::vector<T> res; res.reserve(len);
for(std::size_t i = 0 ; i < len ; ++i)
res.push_back(boost::python::extract<T>(list[i]));
res.push_back(boost::python::extract<T>(iterable[i]));
return res;
}

View File

@@ -137,45 +137,42 @@ namespace detail
throw;
}
inline void check_sizes(std::vector<int> s)
{
if(s.size() < 1 || s.size() > 2)
{
PyErr_SetString(PyExc_TypeError, "Only 1-D and 2-D arrays are supported!");
bp::throw_error_already_set();
}
}
std::shared_ptr<sc::array> create_array(bp::object const & obj, bp::object odtype, bp::object pycontext)
{
return ndarray_to_scarray(np::from_object(obj, to_np_dtype(tools::extract_dtype(odtype))), extract_context(pycontext));
}
std::shared_ptr<sc::array> create_zeros_array(sc::int_t M, sc::int_t N, bp::object odtype, bp::object pycontext)
std::shared_ptr<sc::array> create_zeros_array(bp::object pysizes, bp::object pydtype, bp::object pycontext)
{
return std::shared_ptr<sc::array>(new sc::array(sc::zeros(M, N, tools::extract_dtype(odtype), extract_context(pycontext))));
}
std::shared_ptr<sc::array> create_empty_array(bp::object sizes, bp::object odtype, bp::object pycontext)
{
typedef std::shared_ptr<sc::array> result_type;
std::size_t len;
int size1;
int size2;
try{
len = bp::len(sizes);
size1 = bp::extract<int>(sizes[0])();
size2 = bp::extract<int>(sizes[1])();
}catch(bp::error_already_set const &){
PyErr_Clear();
len = 1;
size1 = bp::extract<int>(sizes)();
}
sc::numeric_type dtype = tools::extract_dtype(odtype);
if(len < 1 || len > 2)
{
PyErr_SetString(PyExc_TypeError, "Only 1-D and 2-D arrays are supported!");
bp::throw_error_already_set();
}
std::vector<int> sizes = tools::to_vector<int>(pysizes);
sc::numeric_type dtype = tools::extract_dtype(pydtype);
sc::driver::Context const & context = extract_context(pycontext);
if(len==1)
return result_type(new sc::array(size1, dtype, context));
return result_type(new sc::array(size1, size2, dtype, context));
check_sizes(sizes);
if(sizes.size()==1)
return std::shared_ptr<sc::array>(new sc::array(sc::zeros({sizes[0]}, dtype, context)));
return std::shared_ptr<sc::array> (new sc::array(sc::zeros({sizes[0], sizes[1]}, dtype, context)));
}
std::shared_ptr<sc::array> create_empty_array(bp::object pysizes, bp::object pydtype, bp::object pycontext)
{
std::vector<int> sizes = tools::to_vector<int>(pysizes);
sc::numeric_type dtype = tools::extract_dtype(pydtype);
sc::driver::Context const & context = extract_context(pycontext);
check_sizes(sizes);
if(sizes.size()==1)
return std::shared_ptr<sc::array>(new sc::array(sizes[0], dtype, context));
return std::shared_ptr<sc::array> (new sc::array(sizes[0], sizes[1], dtype, context));
}
std::string type_name(bp::object const & obj)

View File

@@ -62,7 +62,7 @@ void test_impl(T epsilon, simple_vector_base<T> & cx, simple_vector_base<T>& cy,
}
if(interf == CPP)
{
RUN_TEST("z = 0", cz[i] = 0, z = zeros(N, 1, dtype, context))
RUN_TEST("z = 0", cz[i] = 0, z = sc::zeros({N}, dtype, context))
RUN_TEST("z = x", cz[i] = cx[i], z = x)
RUN_TEST("z = -x", cz[i] = -cx[i], z = -x)