API enhancement

This commit is contained in:
Philippe Tillet
2015-01-20 11:17:42 -05:00
parent 4f73fb384f
commit e74563070a
6 changed files with 25 additions and 14 deletions

View File

@@ -114,7 +114,7 @@ public:
atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context()); atidlas::array_expression eye(std::size_t, std::size_t, atidlas::numeric_type, cl::Context ctx = cl::default_context());
array_expression zeros(std::size_t N, numeric_type dtype); array_expression zeros(std::size_t M, std::size_t N, numeric_type dtype, cl::Context ctx = cl::default_context());
array reshape(array const &, int_t, int_t); array reshape(array const &, int_t, int_t);
//copy //copy
@@ -153,8 +153,8 @@ ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator <=)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator ==)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(operator !=)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(max) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(maximum)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(min) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(minimum)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(pow)
ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot) ATIDLAS_DECLARE_ELEMENT_BINARY_OPERATOR(dot)

View File

@@ -422,8 +422,8 @@ DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_LEQ_TYPE, operator <=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_EQ_TYPE, operator ==)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_NEQ_TYPE, operator !=)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, max) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MAX_TYPE, maximum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, min) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_MIN_TYPE, minimum)
DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow) DEFINE_ELEMENT_BINARY_OPERATOR(OPERATOR_ELEMENT_POW_TYPE, pow)
array_expression outer(array const & x, array const & y) array_expression outer(array const & x, array const & y)
@@ -476,6 +476,11 @@ atidlas::array_expression eye(std::size_t M, std::size_t N, atidlas::numeric_typ
return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); return array_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N));
} }
atidlas::array_expression zeros(std::size_t M, std::size_t N, atidlas::numeric_type dtype, cl::Context ctx)
{
return array_expression(value_scalar(0), lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N));
}
inline size4 trans(size4 const & shape) inline size4 trans(size4 const & shape)
{ return size4(shape._2, shape._1);} { return size4(shape._2, shape._1);}

View File

@@ -277,14 +277,17 @@ void evaluate_expression_traversal::operator()(atidlas::symbolic_expression cons
{ {
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx]; symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
mapping_type::key_type key = std::make_pair(root_idx, leaf); mapping_type::key_type key = std::make_pair(root_idx, leaf);
if (leaf==PARENT_NODE_TYPE && root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY) if (leaf==PARENT_NODE_TYPE)
{ {
if (detail::is_node_leaf(root_node.op)) if (detail::is_node_leaf(root_node.op))
str_ += mapping_.at(key)->evaluate(accessors_); str_ += mapping_.at(key)->evaluate(accessors_);
else if (detail::is_elementwise_operator(root_node.op)) else if(root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
str_ += evaluate(root_node.op.type); {
else if (detail::is_elementwise_function(root_node.op)) if (detail::is_elementwise_operator(root_node.op))
str_ += ","; str_ += evaluate(root_node.op.type);
else if (detail::is_elementwise_function(root_node.op))
str_ += ",";
}
} }
else else
{ {

View File

@@ -1,7 +1,7 @@
#include "atidlas/backend/templates/maxpy.h" #include "atidlas/backend/templates/maxpy.h"
#include "atidlas/tools/make_map.hpp" #include "atidlas/tools/make_map.hpp"
#include "atidlas/tools/make_vector.hpp" #include "atidlas/tools/make_vector.hpp"
#include "atidlas/symbolic/io.h"
#include <iostream> #include <iostream>
namespace atidlas namespace atidlas
@@ -74,7 +74,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream << "}" << std::endl; stream << "}" << std::endl;
// std::cout << stream.str() << std::endl; // std::cout << stream.str() << std::endl;
// std::cout << symbolic_expressions.data().front() << std::endl; // std::cout << to_string(*symbolic_expressions.data().front()) << std::endl;
return stream.str(); return stream.str();
} }

View File

@@ -26,7 +26,7 @@ inline std::string to_string(lhs_rhs_element const & e)
{ {
return"COMPOSITE [" + tools::to_string(e.node_index) + "]"; return"COMPOSITE [" + tools::to_string(e.node_index) + "]";
} }
return tools::to_string(e.dtype); return tools::to_string(e.subtype);
} }
inline std::ostream & operator<<(std::ostream & os, symbolic_expression_node const & s_node) inline std::ostream & operator<<(std::ostream & os, symbolic_expression_node const & s_node)

View File

@@ -13,7 +13,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
using namespace std; using namespace std;
int failure_count = 0; int failure_count = 0;
ad::cl::Context const & ctx = z.context(); ad::numeric_type dtype = x.dtype();
ad::cl::Context const & ctx = x.context();
int_t N = cz.size(); int_t N = cz.size();
@@ -38,6 +39,8 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
std::cout << std::endl;\ std::cout << std::endl;\
} }
RUN_TEST_VECTOR_AXPY("z = 0", cz[i] = 0, z = zeros(N, 1, dtype, ctx))
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x) RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)
RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x) RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x)