Some cleaning + outer product

This commit is contained in:
Philippe Tillet
2015-01-17 10:48:02 -05:00
parent 1d70396711
commit 0068560bc6
28 changed files with 317 additions and 255 deletions

View File

@@ -13,13 +13,14 @@ namespace detail
return op.type==OPERATOR_TRANS_TYPE
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type==OPERATOR_MATRIX_DIAG_TYPE
|| op.type==OPERATOR_VECTOR_DIAG_TYPE
|| op.type==OPERATOR_MATRIX_REPEAT_TYPE
|| op.type==OPERATOR_VDIAG_TYPE
|| op.type==OPERATOR_REPEAT_TYPE
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY;
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type==OPERATOR_OUTER_PROD_TYPE;
}
bool is_scalar_reduction(symbolic_expression_node const & node)
@@ -211,7 +212,7 @@ const char * evaluate(operation_node_type type)
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN";
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT";
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT";
case OPERATOR_VECTOR_DIAG_TYPE : return "vdiag";
case OPERATOR_VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row";
case OPERATOR_MATRIX_COLUMN_TYPE : return "col";
@@ -339,7 +340,7 @@ void evaluate(kernel_generation_stream & stream, leaf_t leaf, std::map<std::stri
stream << evaluate(leaf, accessors, **sit, (*sit)->root(), *mit) << ";" << std::endl;
}
process_traversal::process_traversal(std::multimap<std::string, std::string> const & accessors, kernel_generation_stream & stream,
process_traversal::process_traversal(std::map<std::string, std::string> const & accessors, kernel_generation_stream & stream,
mapping_type const & mapping, std::set<std::string> & already_processed) :
accessors_(accessors), stream_(stream), mapping_(mapping), already_processed_(already_processed)
{ }
@@ -351,19 +352,24 @@ void process_traversal::operator()(symbolic_expression const & /*symbolic_expres
{
mapped_object * obj = it->second.get();
std::string name = obj->name();
if(accessors_.find(name)!=accessors_.end() && already_processed_.insert(obj->process("#name")).second)
for(std::multimap<std::string, std::string>::const_iterator it = accessors_.lower_bound(name) ; it != accessors_.upper_bound(name) ; ++it)
stream_ << obj->process(it->second) << std::endl;
if(accessors_.find(name)!=accessors_.end() && already_processed_.insert(name).second)
for(std::map<std::string, std::string>::const_iterator itt = accessors_.lower_bound(name) ; itt != accessors_.upper_bound(name) ; ++itt)
{
stream_ << obj->process(itt->second) << std::endl;
}
std::string key = obj->type_key();
if(accessors_.find(key)!=accessors_.end() && already_processed_.insert(obj->process("#name")).second)
for(std::multimap<std::string, std::string>::const_iterator it = accessors_.lower_bound(key) ; it != accessors_.upper_bound(key) ; ++it)
stream_ << obj->process(it->second) << std::endl;
if(accessors_.find(key)!=accessors_.end() && already_processed_.insert(name).second)
for(std::map<std::string, std::string>::const_iterator itt = accessors_.lower_bound(key) ; itt != accessors_.upper_bound(key) ; ++itt)
{
stream_ << obj->process(itt->second) << std::endl;
}
}
}
void process(kernel_generation_stream & stream, leaf_t leaf, std::multimap<std::string, std::string> const & accessors,
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
atidlas::symbolic_expression const & symbolic_expression, size_t root_idx, mapping_type const & mapping, std::set<std::string> & already_processed)
{
process_traversal traversal_functor(accessors, stream, mapping, already_processed);
@@ -389,7 +395,7 @@ void process(kernel_generation_stream & stream, leaf_t leaf, std::multimap<std::
}
}
void process(kernel_generation_stream & stream, leaf_t leaf, std::multimap<std::string, std::string> const & accessors,
void process(kernel_generation_stream & stream, leaf_t leaf, std::map<std::string, std::string> const & accessors,
symbolic_expressions_container const & symbolic_expressions, std::vector<mapping_type> const & mappings)
{
symbolic_expressions_container::data_type::const_iterator sit;