API enhancement
This commit is contained in:
@@ -114,7 +114,7 @@ public:
|
||||
|
||||
|
||||
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
|
||||
array_expression zeros(std::size_t N, numeric_type dtype);
|
||||
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, cl::Context ctx = cl::default_context());
|
||||
array reshape(array const &, int_t, int_t);
|
||||
|
||||
//copy
|
||||
@@ -153,8 +153,8 @@ ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator <=)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=)
|
||||
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(max)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(min)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(maximum)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(minimum)
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow)
|
||||
|
||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot)
|
||||
|
@@ -422,8 +422,8 @@ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
|
||||
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, max)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, min)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
|
||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
|
||||
|
||||
array_expression outer(array const & x, array const & y)
|
||||
@@ -476,6 +476,11 @@ atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_typ
|
||||
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
|
||||
}
|
||||
|
||||
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
|
||||
{
|
||||
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
|
||||
}
|
||||
|
||||
inline size4 trans(size4 const & shape)
|
||||
{ return size4(shape._2, shape._1);}
|
||||
|
||||
|
@@ -277,15 +277,18 @@ void evaluate_expression_traversal::operator()(atidlas::symbolic_expression cons
|
||||
{
|
||||
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
|
||||
mapping_type::key_type key = std::make_pair(root_idx, leaf);
|
||||
if (leaf==PARENT_NODE_TYPE && root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
|
||||
if (leaf==PARENT_NODE_TYPE)
|
||||
{
|
||||
if (detail::is_node_leaf(root_node.op))
|
||||
str_ += mapping_.at(key)->evaluate(accessors_);
|
||||
else if (detail::is_elementwise_operator(root_node.op))
|
||||
else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
|
||||
{
|
||||
if (detail::is_elementwise_operator(root_node.op))
|
||||
str_ += evaluate(root_node.op.type);
|
||||
else if (detail::is_elementwise_function(root_node.op))
|
||||
str_ += ",";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (leaf==LHS_NODE_TYPE)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
#include "atidlas/backend/templates/maxpy.h"
|
||||
#include "atidlas/tools/make_map.hpp"
|
||||
#include "atidlas/tools/make_vector.hpp"
|
||||
|
||||
#include "atidlas/symbolic/io.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace atidlas
|
||||
@@ -74,7 +74,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
|
||||
stream << "}" << std::endl;
|
||||
|
||||
// std::cout << stream.str() << std::endl;
|
||||
// std::cout << symbolic_expressions.data().front() << std::endl;
|
||||
// std::cout << to_string(*symbolic_expressions.data().front()) << std::endl;
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
|
@@ -26,7 +26,7 @@ inline std::string to_string(lhs_rhs_element const & e)
|
||||
{
|
||||
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
|
||||
}
|
||||
return tools::to_string(e.dtype);
|
||||
return tools::to_string(e.subtype);
|
||||
}
|
||||
|
||||
inline std::ostream & operator<<(std::ostream & os, symbolic_expression_node const & s_node)
|
||||
|
@@ -13,7 +13,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
|
||||
using namespace std;
|
||||
|
||||
int failure_count = 0;
|
||||
ad::cl::Context const & ctx = z.context();
|
||||
ad::numeric_type dtype = x.dtype();
|
||||
ad::cl::Context const & ctx = x.context();
|
||||
|
||||
int_t N = cz.size();
|
||||
|
||||
@@ -38,6 +39,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
|
||||
std::cout << std::endl;\
|
||||
}
|
||||
|
||||
RUN_TEST_VECTOR_AXPY("z = 0", cz[i] = 0, z = zeros(N, 1, dtype, ctx))
|
||||
|
||||
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)
|
||||
RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x)
|
||||
|
||||
|
Reference in New Issue
Block a user