Cleaning: Largely renamed templates to BLAS-like names
This commit is contained in:
@@ -102,23 +102,23 @@ std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, s
|
||||
}
|
||||
|
||||
|
||||
mapped_mproduct::mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "mproduct"), binary_leaf(info) { }
|
||||
mapped_gemm::mapped_gemm(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "gemm"), binary_leaf(info) { }
|
||||
|
||||
//
|
||||
mapped_reduction::mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
|
||||
mapped_dot::mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
|
||||
mapped_object(scalartype, id, type_key), binary_leaf(info)
|
||||
{ }
|
||||
|
||||
int_t mapped_reduction::root_idx() const
|
||||
int_t mapped_dot::root_idx() const
|
||||
{ return info_.root_idx; }
|
||||
|
||||
isaac::array_expression const & mapped_reduction::array_expression() const
|
||||
isaac::array_expression const & mapped_dot::array_expression() const
|
||||
{ return *info_.array_expression; }
|
||||
|
||||
array_expression::node mapped_reduction::root_node() const
|
||||
array_expression::node mapped_dot::root_node() const
|
||||
{ return array_expression().tree()[root_idx()]; }
|
||||
|
||||
bool mapped_reduction::is_index_reduction() const
|
||||
bool mapped_dot::is_index_dot() const
|
||||
{
|
||||
op_element const & op = root_op();
|
||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
||||
@@ -127,17 +127,17 @@ bool mapped_reduction::is_index_reduction() const
|
||||
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
|
||||
}
|
||||
|
||||
op_element mapped_reduction::root_op() const
|
||||
op_element mapped_dot::root_op() const
|
||||
{
|
||||
return info_.array_expression->tree()[info_.root_idx].op;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
mapped_scalar_reduction::mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ }
|
||||
mapped_scalar_dot::mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "scalar_dot"){ }
|
||||
|
||||
//
|
||||
mapped_mreduction::mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "mreduction") { }
|
||||
mapped_gemv::mapped_gemv(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "gemv") { }
|
||||
|
||||
//
|
||||
void mapped_host_scalar::preprocess(std::string & str) const
|
||||
|
@@ -10,15 +10,15 @@ namespace detail
|
||||
|
||||
|
||||
|
||||
bool is_scalar_reduction(array_expression::node const & node)
|
||||
bool is_scalar_dot(array_expression::node const & node)
|
||||
{
|
||||
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY;
|
||||
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
|
||||
}
|
||||
|
||||
bool is_vector_reduction(array_expression::node const & node)
|
||||
bool is_vector_dot(array_expression::node const & node)
|
||||
{
|
||||
return node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY;
|
||||
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
|
||||
}
|
||||
|
||||
bool is_assignment(op_element const & op)
|
||||
@@ -75,10 +75,10 @@ namespace detail
|
||||
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|
||||
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|
||||
|| op.type==OPERATOR_OUTER_PROD_TYPE
|
||||
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
||||
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
|
||||
;
|
||||
}
|
||||
|
||||
@@ -214,10 +214,10 @@ const char * evaluate(operation_node_type type)
|
||||
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
|
||||
|
||||
//Binary
|
||||
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN";
|
||||
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN";
|
||||
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT";
|
||||
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT";
|
||||
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
|
||||
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
|
||||
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
|
||||
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
|
||||
case OPERATOR_VDIAG_TYPE : return "vdiag";
|
||||
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
|
||||
case OPERATOR_MATRIX_ROW_TYPE : return "row";
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#include "isaac/backend/templates/vaxpy.h"
|
||||
#include "isaac/backend/templates/axpy.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/driver/backend.h"
|
||||
#include "isaac/tools/make_map.hpp"
|
||||
@@ -8,23 +8,24 @@
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
|
||||
vaxpy_parameters::vaxpy_parameters(unsigned int _simd_width,
|
||||
axpy_parameters::axpy_parameters(unsigned int _simd_width,
|
||||
unsigned int _group_size, unsigned int _num_groups,
|
||||
fetching_policy_type _fetching_policy) :
|
||||
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
||||
{ }
|
||||
|
||||
|
||||
int vaxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
int axpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
std::string axpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
{
|
||||
driver::backend_type backend = device.backend();
|
||||
std::string _size_t = size_type(device);
|
||||
@@ -90,25 +91,24 @@ std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const &
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
vaxpy::vaxpy(vaxpy_parameters const & parameters,
|
||||
axpy::axpy(axpy_parameters const & parameters,
|
||||
binding_policy_t binding_policy) :
|
||||
base_impl<vaxpy, vaxpy_parameters>(parameters, binding_policy)
|
||||
base_impl<axpy, axpy_parameters>(parameters, binding_policy)
|
||||
{}
|
||||
|
||||
vaxpy::vaxpy(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||
axpy::axpy(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||
fetching_policy_type fetch, binding_policy_t bind):
|
||||
base_impl<vaxpy, vaxpy_parameters>(vaxpy_parameters(simd,ls,ng,fetch), bind)
|
||||
base_impl<axpy, axpy_parameters>(axpy_parameters(simd,ls,ng,fetch), bind)
|
||||
{}
|
||||
|
||||
|
||||
std::vector<int_t> vaxpy::input_sizes(expressions_tuple const & expressions) const
|
||||
std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) const
|
||||
{
|
||||
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
|
||||
int_t size = static_cast<array_expression const *>(expressions.data().front().get())->shape()[0];
|
||||
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
|
||||
}
|
||||
|
||||
void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
{
|
||||
expressions_tuple const & expressions = controller.x();
|
||||
//Size
|
||||
@@ -135,3 +135,4 @@ void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -2,11 +2,11 @@
|
||||
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/backend/templates/vaxpy.h"
|
||||
#include "isaac/backend/templates/reduction.h"
|
||||
#include "isaac/backend/templates/maxpy.h"
|
||||
#include "isaac/backend/templates/mreduction.h"
|
||||
#include "isaac/backend/templates/mproduct.h"
|
||||
#include "isaac/backend/templates/axpy.h"
|
||||
#include "isaac/backend/templates/dot.h"
|
||||
#include "isaac/backend/templates/ger.h"
|
||||
#include "isaac/backend/templates/gemv.h"
|
||||
#include "isaac/backend/templates/gemm.h"
|
||||
#include "isaac/backend/templates/base.h"
|
||||
#include "isaac/backend/parse.h"
|
||||
#include "isaac/exception/operation_not_supported.h"
|
||||
@@ -17,6 +17,8 @@
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
|
||||
{ }
|
||||
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_scalar_reduction(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_vector_reduction(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_scalar_dot(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_vector_dot(root_node))
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
|
||||
set_arguments(root_node.rhs);
|
||||
}
|
||||
|
||||
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
||||
void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
|
||||
{
|
||||
if (detail::is_elementwise_function(op))
|
||||
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
|
||||
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
|
||||
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
|
||||
}
|
||||
|
||||
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
||||
void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
|
||||
{
|
||||
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
|
||||
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
|
||||
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
|
||||
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
|
||||
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
|
||||
|
||||
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known");
|
||||
default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
|
||||
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
|
||||
}
|
||||
|
||||
bool base::is_reduction(array_expression::node const & node)
|
||||
bool base::is_dot(array_expression::node const & node)
|
||||
{
|
||||
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY;
|
||||
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|
||||
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
|
||||
}
|
||||
|
||||
bool base::is_index_reduction(op_element const & op)
|
||||
bool base::is_index_dot(op_element const & op)
|
||||
{
|
||||
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|
||||
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
|
||||
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
|
||||
return is_invalid_impl(device, expressions);
|
||||
}
|
||||
|
||||
template class base_impl<vaxpy, vaxpy_parameters>;
|
||||
template class base_impl<reduction, reduction_parameters>;
|
||||
template class base_impl<maxpy, maxpy_parameters>;
|
||||
template class base_impl<mreduction, mreduction_parameters>;
|
||||
template class base_impl<mproduct, mproduct_parameters>;
|
||||
template class base_impl<axpy, axpy_parameters>;
|
||||
template class base_impl<dot, dot_parameters>;
|
||||
template class base_impl<ger, ger_parameters>;
|
||||
template class base_impl<gemv, gemv_parameters>;
|
||||
template class base_impl<gemm, gemm_parameters>;
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -1,5 +1,5 @@
|
||||
#include <iostream>
|
||||
#include "isaac/backend/templates/reduction.h"
|
||||
#include "isaac/backend/templates/dot.h"
|
||||
#include <CL/cl.hpp>
|
||||
#include "isaac/tools/to_string.hpp"
|
||||
#include "isaac/tools/make_map.hpp"
|
||||
@@ -7,13 +7,14 @@
|
||||
#include "isaac/backend/keywords.h"
|
||||
namespace isaac
|
||||
{
|
||||
|
||||
reduction_parameters::reduction_parameters(unsigned int _simd_width,
|
||||
namespace templates
|
||||
{
|
||||
dot_parameters::dot_parameters(unsigned int _simd_width,
|
||||
unsigned int _group_size, unsigned int _num_groups,
|
||||
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
|
||||
{ }
|
||||
|
||||
unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
|
||||
unsigned int dot::lmem_usage(expressions_tuple const & expressions) const
|
||||
{
|
||||
unsigned int res = 0;
|
||||
for(const auto & elem : expressions.data())
|
||||
@@ -24,14 +25,14 @@ unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
|
||||
return res;
|
||||
}
|
||||
|
||||
int reduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
int dot::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if (p_.fetching_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
|
||||
inline void dot::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
|
||||
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
|
||||
{
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
@@ -44,26 +45,26 @@ inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream,
|
||||
stream.inc_tab();
|
||||
|
||||
for (auto & expr : exprs)
|
||||
if (expr->is_index_reduction())
|
||||
compute_index_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
|
||||
if (expr->is_index_dot())
|
||||
compute_index_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
|
||||
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
|
||||
expr->root_op());
|
||||
else
|
||||
compute_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
|
||||
compute_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
}
|
||||
|
||||
std::string reduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
std::string dot::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
{
|
||||
kernel_generation_stream stream;
|
||||
|
||||
std::vector<mapped_scalar_reduction*> exprs;
|
||||
std::vector<mapped_scalar_dot*> exprs;
|
||||
for (const auto & mapping : mappings)
|
||||
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
|
||||
if (mapped_scalar_reduction * p = dynamic_cast<mapped_scalar_reduction*>(iit->second.get()))
|
||||
if (mapped_scalar_dot * p = dynamic_cast<mapped_scalar_dot*>(iit->second.get()))
|
||||
exprs.push_back(p);
|
||||
std::size_t N = exprs.size();
|
||||
driver::backend_type backend = device.backend();
|
||||
@@ -73,7 +74,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
|
||||
if (exprs[k]->is_index_reduction())
|
||||
if (exprs[k]->is_index_dot())
|
||||
{
|
||||
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
|
||||
@@ -112,7 +113,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
if (exprs[k]->is_index_reduction())
|
||||
if (exprs[k]->is_index_dot())
|
||||
{
|
||||
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
|
||||
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||
@@ -156,11 +157,11 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
accessors["matrix_diag"] = str[a];
|
||||
accessors["array0"] = "#namereg";
|
||||
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
|
||||
if (elem->is_index_reduction())
|
||||
compute_index_reduction(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
|
||||
if (elem->is_index_dot())
|
||||
compute_index_dot(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
|
||||
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
|
||||
else
|
||||
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -168,7 +169,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
//Fills local memory
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
if (exprs[k]->is_index_reduction())
|
||||
if (exprs[k]->is_index_dot())
|
||||
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
||||
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||
}
|
||||
@@ -182,7 +183,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
stream.inc_tab();
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
if (exprs[k]->is_index_reduction())
|
||||
if (exprs[k]->is_index_dot())
|
||||
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
|
||||
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
|
||||
}
|
||||
@@ -205,9 +206,9 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
|
||||
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
|
||||
|
||||
for (mapped_scalar_reduction* e: exprs)
|
||||
for (mapped_scalar_dot* e: exprs)
|
||||
{
|
||||
if (e->is_index_reduction())
|
||||
if (e->is_index_dot())
|
||||
{
|
||||
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
|
||||
stream << e->process("unsigned int #name_acc = 0;") << std::endl;
|
||||
@@ -224,18 +225,18 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
for (mapped_scalar_reduction* e: exprs)
|
||||
if (e->is_index_reduction())
|
||||
compute_index_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
|
||||
for (mapped_scalar_dot* e: exprs)
|
||||
if (e->is_index_dot())
|
||||
compute_index_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
|
||||
else
|
||||
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
|
||||
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
for (unsigned int k = 0; k < N; ++k)
|
||||
{
|
||||
if (exprs[k]->is_index_reduction())
|
||||
if (exprs[k]->is_index_dot())
|
||||
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
|
||||
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
|
||||
}
|
||||
@@ -248,7 +249,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
std::map<std::string, std::string> accessors;
|
||||
accessors["scalar_reduction"] = "#name_buf[0]";
|
||||
accessors["scalar_dot"] = "#name_buf[0]";
|
||||
accessors["array0"] = "#pointer[#start]";
|
||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||
stream.dec_tab();
|
||||
@@ -260,23 +261,23 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
reduction::reduction(reduction::parameters_type const & parameters,
|
||||
binding_policy_t binding) : base_impl<reduction, reduction_parameters>(parameters, binding)
|
||||
dot::dot(dot::parameters_type const & parameters,
|
||||
binding_policy_t binding) : base_impl<dot, dot_parameters>(parameters, binding)
|
||||
{ }
|
||||
|
||||
reduction::reduction(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||
dot::dot(unsigned int simd, unsigned int ls, unsigned int ng,
|
||||
fetching_policy_type fetch, binding_policy_t bind):
|
||||
base_impl<reduction, reduction_parameters>(reduction_parameters(simd,ls,ng,fetch), bind)
|
||||
base_impl<dot, dot_parameters>(dot_parameters(simd,ls,ng,fetch), bind)
|
||||
{}
|
||||
|
||||
std::vector<int_t> reduction::input_sizes(expressions_tuple const & expressions) const
|
||||
std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
|
||||
{
|
||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *(expressions.data().front()), false);
|
||||
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), reductions_idx[0]));
|
||||
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *(expressions.data().front()), false);
|
||||
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), dots_idx[0]));
|
||||
return tools::make_vector<int_t>() << N;
|
||||
}
|
||||
|
||||
void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
{
|
||||
expressions_tuple const & expressions = controller.x();
|
||||
|
||||
@@ -290,12 +291,12 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<array_expression::node const *> reductions;
|
||||
std::vector<array_expression::node const *> dots;
|
||||
for (const auto & elem : expressions.data())
|
||||
{
|
||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *elem, false);
|
||||
for (auto & reductions_idx_itt : reductions_idx)
|
||||
reductions.push_back(&(elem)->tree()[reductions_idx_itt]);
|
||||
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *elem, false);
|
||||
for (auto & dots_idx_itt : dots_idx)
|
||||
dots.push_back(&(elem)->tree()[dots_idx_itt]);
|
||||
}
|
||||
|
||||
//Kernel
|
||||
@@ -321,9 +322,9 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
||||
//Temporary buffers
|
||||
unsigned int i = 0;
|
||||
unsigned int j = 0;
|
||||
for (std::vector<array_expression::node const *>::const_iterator it = reductions.begin(); it != reductions.end(); ++it)
|
||||
for (std::vector<array_expression::node const *>::const_iterator it = dots.begin(); it != dots.end(); ++it)
|
||||
{
|
||||
if (is_index_reduction((*it)->op))
|
||||
if (is_index_dot((*it)->op))
|
||||
{
|
||||
if (tmpidx_.size() <= j)
|
||||
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
|
||||
@@ -343,3 +344,4 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,5 +1,5 @@
|
||||
#include "isaac/array.h"
|
||||
#include "isaac/backend/templates/mproduct.h"
|
||||
#include "isaac/backend/templates/gemm.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/model/model.h"
|
||||
#include "isaac/symbolic/preset.h"
|
||||
@@ -10,8 +10,10 @@
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
|
||||
@@ -21,7 +23,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
mL(ms*local_size_0), nL(ns*local_size_1){}
|
||||
|
||||
|
||||
unsigned int mproduct::lmem_usage(expressions_tuple const & expressions) const
|
||||
unsigned int gemm::lmem_usage(expressions_tuple const & expressions) const
|
||||
{
|
||||
isaac::array_expression const & array_expression = (*expressions.data().front());
|
||||
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
||||
@@ -32,7 +34,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
unsigned int mproduct::registers_usage(expressions_tuple const & expressions) const
|
||||
unsigned int gemm::registers_usage(expressions_tuple const & expressions) const
|
||||
{
|
||||
isaac::array_expression const & array_expression = (*expressions.data().front());
|
||||
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
|
||||
@@ -41,7 +43,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
int mproduct::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
@@ -95,7 +97,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string mproduct::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
|
||||
std::string gemm::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
|
||||
{
|
||||
using std::string;
|
||||
using tools::to_string;
|
||||
@@ -437,7 +439,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
#undef VST0RE
|
||||
}
|
||||
|
||||
void mproduct::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K,
|
||||
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
||||
array const & A, array const & B, array const & C,
|
||||
value_scalar const & alpha, value_scalar const & beta,
|
||||
driver::Program & program, const char * suffix, execution_options_type const & options)
|
||||
@@ -516,7 +518,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
}
|
||||
}
|
||||
|
||||
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
|
||||
array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
|
||||
{
|
||||
slice s0(s0_0, s0_1);
|
||||
slice s1(s1_0, s1_1);
|
||||
@@ -525,7 +527,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return array(M, s0, s1);
|
||||
}
|
||||
|
||||
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
|
||||
std::vector<int_t> gemm::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
|
||||
{
|
||||
isaac::array_expression & array_expression = (*expressions.data().front());
|
||||
array_expression::container_type & array = array_expression.tree();
|
||||
@@ -537,26 +539,26 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
return {M, N, K};
|
||||
}
|
||||
|
||||
mproduct::mproduct(mproduct_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN_TYPE;
|
||||
else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT_TYPE;
|
||||
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
|
||||
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
|
||||
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
|
||||
else throw;
|
||||
}
|
||||
|
||||
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) const
|
||||
std::vector<int_t> gemm::input_sizes(expressions_tuple const & expressions) const
|
||||
{
|
||||
symbolic::preset::gemm::args dummy;
|
||||
return infos(expressions, dummy);
|
||||
}
|
||||
|
||||
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||
{
|
||||
using namespace tools;
|
||||
|
||||
mproduct & fallback = (mproduct&)fallback_base;
|
||||
gemm & fallback = (gemm&)fallback_base;
|
||||
expressions_tuple const & expressions = ctr.x();
|
||||
|
||||
|
||||
@@ -579,8 +581,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
int_t ldstrideA = pA->stride()[0];
|
||||
int_t ldstrideB = pB->stride()[0];
|
||||
int_t ldstrideC = pC->stride()[0];
|
||||
int_t ldstartA = pA->start()[0];
|
||||
int_t ldstartB = pB->start()[0];
|
||||
|
||||
numeric_type dtype = args.C->dtype;
|
||||
|
||||
@@ -613,40 +613,41 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
}
|
||||
|
||||
//
|
||||
mproduct_nn::mproduct_nn(unsigned int simd
|
||||
gemm_nn::gemm_nn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
|
||||
{ }
|
||||
|
||||
//
|
||||
mproduct_tn::mproduct_tn(unsigned int simd
|
||||
gemm_tn::gemm_tn(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
|
||||
{ }
|
||||
|
||||
//
|
||||
mproduct_nt::mproduct_nt(unsigned int simd
|
||||
gemm_nt::gemm_nt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
|
||||
{ }
|
||||
|
||||
//
|
||||
mproduct_tt::mproduct_tt(unsigned int simd
|
||||
gemm_tt::gemm_tt(unsigned int simd
|
||||
, int_t ls0, int_t KL, int_t ls1, int_t D
|
||||
, int_t ms, int_t ks, int_t ns
|
||||
, fetching_policy_type Afetch , fetching_policy_type Bfetch
|
||||
, int_t lfetch0, int_t lfetch1, bool check_bound) :
|
||||
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
||||
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
|
||||
{ }
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,48 +1,50 @@
|
||||
#include <iostream>
|
||||
#include "isaac/backend/stream.h"
|
||||
#include "isaac/backend/keywords.h"
|
||||
#include "isaac/backend/templates/mreduction.h"
|
||||
#include "isaac/backend/templates/gemv.h"
|
||||
#include "isaac/tools/to_string.hpp"
|
||||
#include "isaac/tools/make_map.hpp"
|
||||
#include "isaac/tools/make_vector.hpp"
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
mreduction_parameters::mreduction_parameters(unsigned int _simd_width,
|
||||
gemv_parameters::gemv_parameters(unsigned int _simd_width,
|
||||
unsigned int _local_size_0, unsigned int _local_size_1,
|
||||
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
|
||||
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
|
||||
|
||||
|
||||
int mreduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if(reduction_type_==REDUCE_ROWS && p_.simd_width>1)
|
||||
if(dot_type_==REDUCE_ROWS && p_.simd_width>1)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
if (p_.fetch_policy==FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
unsigned int mreduction::lmem_usage() const
|
||||
unsigned int gemv::lmem_usage() const
|
||||
{
|
||||
return (p_.local_size_0+1)*p_.local_size_1;
|
||||
}
|
||||
|
||||
std::string mreduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
std::string gemv::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
{
|
||||
using tools::to_string;
|
||||
|
||||
|
||||
std::vector<mapped_mreduction*> reductions;
|
||||
std::vector<mapped_gemv*> dots;
|
||||
expressions_tuple::data_type::const_iterator sit;
|
||||
std::vector<mapping_type>::const_iterator mit;
|
||||
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
|
||||
{
|
||||
array_expression const & first_expression = *expressions.data().front();
|
||||
std::vector<size_t> idx = filter_nodes(&is_reduction, first_expression, false);
|
||||
std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
|
||||
for (auto & elem : idx)
|
||||
reductions.push_back((mapped_mreduction*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
|
||||
dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
|
||||
}
|
||||
|
||||
kernel_generation_stream stream;
|
||||
@@ -54,10 +56,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
strcat(name[1], suffix);
|
||||
|
||||
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
{
|
||||
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
|
||||
if (e->is_index_reduction())
|
||||
if (e->is_index_dot())
|
||||
{
|
||||
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
|
||||
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
|
||||
@@ -87,7 +89,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
unsigned int local_size_0_ld = p_.local_size_0;
|
||||
std::string local_size_0_ld_str = to_string(local_size_0_ld);
|
||||
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
||||
|
||||
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
||||
@@ -104,7 +106,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||
|
||||
stream << "if (r < M)" << std::endl;
|
||||
@@ -116,10 +118,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
std::string data_type = append_width("#scalartype",simd_width);
|
||||
|
||||
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
if(reduction_type_==REDUCE_COLUMNS)
|
||||
if(dot_type_==REDUCE_COLUMNS)
|
||||
{
|
||||
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
|
||||
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
|
||||
@@ -141,20 +143,20 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
str[a] = access_vector_type("#namereg",a);
|
||||
|
||||
|
||||
for (auto & elem : reductions)
|
||||
for (auto & elem : dots)
|
||||
for (unsigned int a = 0; a < simd_width; ++a)
|
||||
{
|
||||
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
|
||||
if (elem->is_index_reduction())
|
||||
compute_index_reduction(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
|
||||
if (elem->is_index_dot())
|
||||
compute_index_dot(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
|
||||
else
|
||||
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
|
||||
}
|
||||
});
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
for (auto & expr : reductions)
|
||||
for (auto & expr : dots)
|
||||
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
||||
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
@@ -167,13 +169,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
for (auto & e : reductions)
|
||||
if (e->is_index_reduction())
|
||||
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
for (auto & e : dots)
|
||||
if (e->is_index_dot())
|
||||
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
, e->root_op());
|
||||
else
|
||||
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
@@ -188,15 +190,15 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
if(p_.num_groups_0==1)
|
||||
{
|
||||
std::map<std::string, std::string> accessors;
|
||||
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||
accessors["array1"] = "#pointer[r*#stride]";
|
||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (mapped_reduction const * e : reductions)
|
||||
for (mapped_dot const * e : dots)
|
||||
{
|
||||
if (e->is_index_reduction())
|
||||
if (e->is_index_dot())
|
||||
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
||||
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
|
||||
}
|
||||
@@ -230,7 +232,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
{"array2", "#pointer += #start1 + #start2*#ld; "
|
||||
"#ld *= #nldstride; "}}, expressions, mappings);
|
||||
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
||||
|
||||
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
|
||||
@@ -246,7 +248,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
for (const auto & e : reductions)
|
||||
for (const auto & e : dots)
|
||||
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
|
||||
|
||||
stream << "if (r < M)" << std::endl;
|
||||
@@ -256,8 +258,8 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
for (mapped_reduction* e: reductions)
|
||||
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
|
||||
for (mapped_dot* e: dots)
|
||||
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
@@ -266,7 +268,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
||||
for (auto & expr : reductions)
|
||||
for (auto & expr : dots)
|
||||
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
|
||||
|
||||
stream << "#pragma unroll" << std::endl;
|
||||
@@ -279,13 +281,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
for (auto & e : reductions)
|
||||
if (e->is_index_reduction())
|
||||
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
for (auto & e : dots)
|
||||
if (e->is_index_dot())
|
||||
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
|
||||
, e->root_op());
|
||||
else
|
||||
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
@@ -299,7 +301,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
stream.inc_tab();
|
||||
|
||||
std::map<std::string, std::string> accessors;
|
||||
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
|
||||
accessors["array1"] = "#pointer[r*#stride]";
|
||||
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
|
||||
|
||||
@@ -317,38 +319,38 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
mreduction::mreduction(mreduction::parameters_type const & parameters,
|
||||
mreduction::reduction_type rtype,
|
||||
gemv::gemv(gemv::parameters_type const & parameters,
|
||||
gemv::dot_type rtype,
|
||||
binding_policy_t binding_policy) :
|
||||
base_impl<mreduction, mreduction_parameters>(parameters, binding_policy),
|
||||
reduction_type_(rtype){ }
|
||||
base_impl<gemv, gemv_parameters>(parameters, binding_policy),
|
||||
dot_type_(rtype){ }
|
||||
|
||||
std::vector<int_t> mreduction::input_sizes(expressions_tuple const & expressions) const
|
||||
std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
|
||||
{
|
||||
array_expression const & first_expression = *expressions.data().front();
|
||||
std::vector<std::size_t> idx = filter_nodes(&is_reduction, first_expression, false);
|
||||
std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
|
||||
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
|
||||
if(reduction_type_==REDUCE_COLUMNS)
|
||||
if(dot_type_==REDUCE_COLUMNS)
|
||||
std::swap(MN.first,MN.second);
|
||||
return tools::make_vector<int_t>() << MN.first << MN.second;
|
||||
}
|
||||
|
||||
void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||
{
|
||||
expressions_tuple const & expressions = controller.x();
|
||||
driver::Context const & context = expressions.context();
|
||||
|
||||
std::vector<int_t> MN = input_sizes(expressions);
|
||||
std::vector<array_expression::node const *> reductions;
|
||||
std::vector<array_expression::node const *> dots;
|
||||
for (const auto & e : expressions.data())
|
||||
{
|
||||
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *e, false);
|
||||
for (auto & r : reductions_idx)
|
||||
reductions.push_back(&(e)->tree()[r]);
|
||||
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
|
||||
for (auto & r : dots_idx)
|
||||
dots.push_back(&(e)->tree()[r]);
|
||||
}
|
||||
|
||||
//Fallback
|
||||
if(reduction_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
|
||||
if(dot_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
|
||||
{
|
||||
fallback.enqueue(queue, program, "fallback", fallback, controller);
|
||||
return;
|
||||
@@ -381,9 +383,9 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
|
||||
//Temporary buffers
|
||||
unsigned int i = 0;
|
||||
unsigned int j = 0;
|
||||
for (auto const & r : reductions)
|
||||
for (auto const & r : dots)
|
||||
{
|
||||
if (is_index_reduction(r->op))
|
||||
if (is_index_dot(r->op))
|
||||
{
|
||||
if (tmpidx.size() <= j)
|
||||
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
|
||||
@@ -405,24 +407,25 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
|
||||
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
|
||||
}
|
||||
|
||||
mreduction_rows::mreduction_rows(mreduction_parameters const & parameters,
|
||||
gemv_n::gemv_n(gemv_parameters const & parameters,
|
||||
binding_policy_t binding_policy):
|
||||
mreduction(parameters, REDUCE_ROWS, binding_policy){}
|
||||
gemv(parameters, REDUCE_ROWS, binding_policy){}
|
||||
|
||||
mreduction_rows::mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
||||
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
|
||||
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
|
||||
{}
|
||||
|
||||
|
||||
mreduction_cols::mreduction_cols(mreduction::parameters_type const & parameters,
|
||||
gemv_t::gemv_t(gemv::parameters_type const & parameters,
|
||||
binding_policy_t binding_policy):
|
||||
mreduction(parameters, REDUCE_COLUMNS, binding_policy){}
|
||||
gemv(parameters, REDUCE_COLUMNS, binding_policy){}
|
||||
|
||||
mreduction_cols::mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
|
||||
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
|
||||
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
|
||||
{}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
#include "isaac/backend/templates/maxpy.h"
|
||||
#include "isaac/backend/templates/ger.h"
|
||||
#include "isaac/tools/make_map.hpp"
|
||||
#include "isaac/tools/make_vector.hpp"
|
||||
#include "isaac/symbolic/io.h"
|
||||
@@ -7,15 +7,17 @@
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace templates
|
||||
{
|
||||
|
||||
maxpy_parameters::maxpy_parameters(unsigned int _simd_width,
|
||||
ger_parameters::ger_parameters(unsigned int _simd_width,
|
||||
unsigned int _local_size_0, unsigned int _local_size_1,
|
||||
unsigned int _num_groups_0, unsigned int _num_groups_1,
|
||||
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
|
||||
|
||||
|
||||
|
||||
int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
int ger::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if (p_.simd_width>1)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
@@ -24,7 +26,7 @@ int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) co
|
||||
return TEMPLATE_VALID;
|
||||
}
|
||||
|
||||
std::string maxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
std::string ger::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
|
||||
{
|
||||
kernel_generation_stream stream;
|
||||
std::string _size_t = size_type(device);
|
||||
@@ -95,23 +97,23 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
maxpy::maxpy(parameters_type const & parameters, binding_policy_t binding_policy) :
|
||||
base_impl<maxpy, maxpy_parameters>(parameters, binding_policy){ }
|
||||
ger::ger(parameters_type const & parameters, binding_policy_t binding_policy) :
|
||||
base_impl<ger, ger_parameters>(parameters, binding_policy){ }
|
||||
|
||||
maxpy::maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
ger::ger(unsigned int simd, unsigned int ls1, unsigned int ls2,
|
||||
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
|
||||
binding_policy_t bind):
|
||||
base_impl<maxpy, maxpy_parameters>(maxpy_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
|
||||
base_impl<ger, ger_parameters>(ger_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
|
||||
{}
|
||||
|
||||
std::vector<int_t> maxpy::input_sizes(expressions_tuple const & expressions) const
|
||||
std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
|
||||
{
|
||||
isaac::array_expression const & array_expression = *(expressions.data().front());
|
||||
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
|
||||
return tools::make_vector<int_t>() << size.first << size.second;
|
||||
}
|
||||
|
||||
void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
||||
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
||||
{
|
||||
expressions_tuple const & expressions = controller.x();
|
||||
char name[32] = {"axpy"};
|
||||
@@ -129,3 +131,4 @@ void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user