API enhancement
This commit is contained in:
@@ -114,7 +114,7 @@ public:
|
|||||||
|
|
||||||
|
|
||||||
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
|
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
|
||||||
array_expression zeros(std::size_t N, numeric_type dtype);
|
array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, cl::Context ctx = cl::default_context());
|
||||||
array reshape(array const &, int_t, int_t);
|
array reshape(array const &, int_t, int_t);
|
||||||
|
|
||||||
//copy
|
//copy
|
||||||
@@ -153,8 +153,8 @@ ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator <=)
|
|||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==)
|
||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=)
|
||||||
|
|
||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(max)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(maximum)
|
||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(min)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(minimum)
|
||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow)
|
||||||
|
|
||||||
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot)
|
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot)
|
||||||
|
@@ -422,8 +422,8 @@ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
|
|||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
|
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
|
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
|
||||||
|
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, max)
|
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, min)
|
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
|
||||||
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
|
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
|
||||||
|
|
||||||
array_expression outer(array const & x, array const & y)
|
array_expression outer(array const & x, array const & y)
|
||||||
@@ -476,6 +476,11 @@ atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_typ
|
|||||||
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
|
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
|
||||||
|
{
|
||||||
|
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
|
||||||
|
}
|
||||||
|
|
||||||
inline size4 trans(size4 const & shape)
|
inline size4 trans(size4 const & shape)
|
||||||
{ return size4(shape._2, shape._1);}
|
{ return size4(shape._2, shape._1);}
|
||||||
|
|
||||||
|
@@ -277,14 +277,17 @@ void evaluate_expression_traversal::operator()(atidlas::symbolic_expression cons
|
|||||||
{
|
{
|
||||||
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
|
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
|
||||||
mapping_type::key_type key = std::make_pair(root_idx, leaf);
|
mapping_type::key_type key = std::make_pair(root_idx, leaf);
|
||||||
if (leaf==PARENT_NODE_TYPE && root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
|
if (leaf==PARENT_NODE_TYPE)
|
||||||
{
|
{
|
||||||
if (detail::is_node_leaf(root_node.op))
|
if (detail::is_node_leaf(root_node.op))
|
||||||
str_ += mapping_.at(key)->evaluate(accessors_);
|
str_ += mapping_.at(key)->evaluate(accessors_);
|
||||||
else if (detail::is_elementwise_operator(root_node.op))
|
else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
|
||||||
str_ += evaluate(root_node.op.type);
|
{
|
||||||
else if (detail::is_elementwise_function(root_node.op))
|
if (detail::is_elementwise_operator(root_node.op))
|
||||||
str_ += ",";
|
str_ += evaluate(root_node.op.type);
|
||||||
|
else if (detail::is_elementwise_function(root_node.op))
|
||||||
|
str_ += ",";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
#include "atidlas/backend/templates/maxpy.h"
|
#include "atidlas/backend/templates/maxpy.h"
|
||||||
#include "atidlas/tools/make_map.hpp"
|
#include "atidlas/tools/make_map.hpp"
|
||||||
#include "atidlas/tools/make_vector.hpp"
|
#include "atidlas/tools/make_vector.hpp"
|
||||||
|
#include "atidlas/symbolic/io.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
namespace atidlas
|
namespace atidlas
|
||||||
@@ -74,7 +74,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
|
|||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
// std::cout << stream.str() << std::endl;
|
// std::cout << stream.str() << std::endl;
|
||||||
// std::cout << symbolic_expressions.data().front() << std::endl;
|
// std::cout << to_string(*symbolic_expressions.data().front()) << std::endl;
|
||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -26,7 +26,7 @@ inline std::string to_string(lhs_rhs_element const & e)
|
|||||||
{
|
{
|
||||||
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
|
return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
|
||||||
}
|
}
|
||||||
return tools::to_string(e.dtype);
|
return tools::to_string(e.subtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::ostream & operator<<(std::ostream & os, symbolic_expression_node const & s_node)
|
inline std::ostream & operator<<(std::ostream & os, symbolic_expression_node const & s_node)
|
||||||
|
@@ -13,7 +13,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
int failure_count = 0;
|
int failure_count = 0;
|
||||||
ad::cl::Context const & ctx = z.context();
|
ad::numeric_type dtype = x.dtype();
|
||||||
|
ad::cl::Context const & ctx = x.context();
|
||||||
|
|
||||||
int_t N = cz.size();
|
int_t N = cz.size();
|
||||||
|
|
||||||
@@ -38,6 +39,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
|
|||||||
std::cout << std::endl;\
|
std::cout << std::endl;\
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RUN_TEST_VECTOR_AXPY("z = 0", cz[i] = 0, z = zeros(N, 1, dtype, ctx))
|
||||||
|
|
||||||
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)
|
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)
|
||||||
RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x)
|
RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user