Cleaning: Largely renamed templates to BLAS-like names
This commit is contained in:
@@ -2,11 +2,11 @@
|
||||
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/backend/templates/vaxpy.h"
|
||||
#include "isaac/backend/templates/reduction.h"
|
||||
#include "isaac/backend/templates/maxpy.h"
|
||||
#include "isaac/backend/templates/mreduction.h"
|
||||
#include "isaac/backend/templates/mproduct.h"
|
||||
#include "isaac/backend/templates/axpy.h"
|
||||
#include "isaac/backend/templates/dot.h"
|
||||
#include "isaac/backend/templates/ger.h"
|
||||
#include "isaac/backend/templates/gemv.h"
|
||||
#include "isaac/backend/templates/gemm.h"
|
||||
#include "isaac/backend/templates/base.h"
|
||||
#include "isaac/backend/parse.h"
|
||||
#include "isaac/exception/operation_not_supported.h"
|
||||
@@ -17,6 +17,8 @@
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
|
||||
{ }
|
||||
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_scalar_reduction(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_vector_reduction(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_scalar_dot(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_vector_dot(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
|
||||
set_arguments(root_node.rhs);
|
||||
}
|
||||
|
||||
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
||||
void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
||||
{
|
||||
if (detail::is_elementwise_function(op))
|
||||
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
|
||||
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
|
||||
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
|
||||
}
|
||||
|
||||
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
||||
void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
||||
{
|
||||
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
||||
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
||||
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
|
||||
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
|
||||
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
|
||||
|
||||
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known");
|
||||
default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
||||
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
|
||||
}
|
||||
|
||||
bool base::is_reduction(array_expression::node const & node)
|
||||
bool base::is_dot(array_expression::node const & node)
|
||||
{
|
||||
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY;
|
||||
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
|
||||
}
|
||||
|
||||
bool base::is_index_reduction(op_element const & op)
|
||||
bool base::is_index_dot(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
||||
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
||||
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
|
||||
return is_invalid_impl(device, expressions);
|
||||
}
|
||||
|
||||
template class base_impl<vaxpy, vaxpy_parameters>;
|
||||
template class base_impl<reduction, reduction_parameters>;
|
||||
template class base_impl<maxpy, maxpy_parameters>;
|
||||
template class base_impl<mreduction, mreduction_parameters>;
|
||||
template class base_impl<mproduct, mproduct_parameters>;
|
||||
template class base_impl<axpy, axpy_parameters>;
|
||||
template class base_impl<dot, dot_parameters>;
|
||||
template class base_impl<ger, ger_parameters>;
|
||||
template class base_impl<gemv, gemv_parameters>;
|
||||
template class base_impl<gemm, gemm_parameters>;
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user