Cleaning: Largely renamed templates to BLAS-like names

This commit is contained in:
Philippe Tillet
2015-07-11 09:36:01 -04:00
parent 281fa9c7a6
commit cfa6ea812d
40 changed files with 606 additions and 572 deletions

View File

@@ -102,23 +102,23 @@ std::string binary_leaf::evaluate_recursive(leaf_t leaf, std::map<std::string, s
}
mapped_mproduct::mapped_mproduct(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "mproduct"), binary_leaf(info) { }
mapped_gemm::mapped_gemm(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "gemm"), binary_leaf(info) { }
//
mapped_reduction::mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
mapped_dot::mapped_dot(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) :
mapped_object(scalartype, id, type_key), binary_leaf(info)
{ }
int_t mapped_reduction::root_idx() const
int_t mapped_dot::root_idx() const
{ return info_.root_idx; }
isaac::array_expression const & mapped_reduction::array_expression() const
isaac::array_expression const & mapped_dot::array_expression() const
{ return *info_.array_expression; }
array_expression::node mapped_reduction::root_node() const
array_expression::node mapped_dot::root_node() const
{ return array_expression().tree()[root_idx()]; }
bool mapped_reduction::is_index_reduction() const
bool mapped_dot::is_index_dot() const
{
op_element const & op = root_op();
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
@@ -127,17 +127,17 @@ bool mapped_reduction::is_index_reduction() const
|| op.type==OPERATOR_ELEMENT_ARGMIN_TYPE;
}
op_element mapped_reduction::root_op() const
op_element mapped_dot::root_op() const
{
return info_.array_expression->tree()[info_.root_idx].op;
}
//
mapped_scalar_reduction::mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ }
mapped_scalar_dot::mapped_scalar_dot(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "scalar_dot"){ }
//
mapped_mreduction::mapped_mreduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "mreduction") { }
mapped_gemv::mapped_gemv(std::string const & scalartype, unsigned int id, node_info info) : mapped_dot(scalartype, id, info, "gemv") { }
//
void mapped_host_scalar::preprocess(std::string & str) const

View File

@@ -10,15 +10,15 @@ namespace detail
bool is_scalar_reduction(array_expression::node const & node)
bool is_scalar_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY;
}
bool is_vector_reduction(array_expression::node const & node)
bool is_vector_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY;
}
bool is_assignment(op_element const & op)
@@ -75,10 +75,10 @@ namespace detail
|| op.type==OPERATOR_MATRIX_ROW_TYPE
|| op.type==OPERATOR_MATRIX_COLUMN_TYPE
|| op.type==OPERATOR_OUTER_PROD_TYPE
|| op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| op.type_family==OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY
|| op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| op.type_family==OPERATOR_GEMM_TYPE_FAMILY
;
}
@@ -214,10 +214,10 @@ const char * evaluate(operation_node_type type)
case OPERATOR_ELEMENT_MIN_TYPE : return "min";
//Binary
case OPERATOR_MATRIX_PRODUCT_NN_TYPE : return "prodNN";
case OPERATOR_MATRIX_PRODUCT_TN_TYPE : return "prodTN";
case OPERATOR_MATRIX_PRODUCT_NT_TYPE : return "prodNT";
case OPERATOR_MATRIX_PRODUCT_TT_TYPE : return "prodTT";
case OPERATOR_GEMM_NN_TYPE : return "prodNN";
case OPERATOR_GEMM_TN_TYPE : return "prodTN";
case OPERATOR_GEMM_NT_TYPE : return "prodNT";
case OPERATOR_GEMM_TT_TYPE : return "prodTT";
case OPERATOR_VDIAG_TYPE : return "vdiag";
case OPERATOR_MATRIX_DIAG_TYPE : return "mdiag";
case OPERATOR_MATRIX_ROW_TYPE : return "row";

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/keywords.h"
#include "isaac/driver/backend.h"
#include "isaac/tools/make_map.hpp"
@@ -8,23 +8,24 @@
namespace isaac
{
namespace templates
{
vaxpy_parameters::vaxpy_parameters(unsigned int _simd_width,
axpy_parameters::axpy_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) :
base::parameters_type(_simd_width, _group_size, 1, 1), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ }
int vaxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int axpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string axpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
driver::backend_type backend = device.backend();
std::string _size_t = size_type(device);
@@ -90,25 +91,24 @@ std::string vaxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str();
}
vaxpy::vaxpy(vaxpy_parameters const & parameters,
axpy::axpy(axpy_parameters const & parameters,
binding_policy_t binding_policy) :
base_impl<vaxpy, vaxpy_parameters>(parameters, binding_policy)
base_impl<axpy, axpy_parameters>(parameters, binding_policy)
{}
vaxpy::vaxpy(unsigned int simd, unsigned int ls, unsigned int ng,
axpy::axpy(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind):
base_impl<vaxpy, vaxpy_parameters>(vaxpy_parameters(simd,ls,ng,fetch), bind)
base_impl<axpy, axpy_parameters>(axpy_parameters(simd,ls,ng,fetch), bind)
{}
std::vector<int_t> vaxpy::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) const
{
size4 shape = static_cast<array_expression const *>(expressions.data().front().get())->shape();
int_t size = static_cast<array_expression const *>(expressions.data().front().get())->shape()[0];
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
}
void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
//Size
@@ -135,3 +135,4 @@ void vaxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
}
}

View File

@@ -2,11 +2,11 @@
#include "isaac/array.h"
#include "isaac/backend/keywords.h"
#include "isaac/backend/templates/vaxpy.h"
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/axpy.h"
#include "isaac/backend/templates/dot.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/backend/templates/base.h"
#include "isaac/backend/parse.h"
#include "isaac/exception/operation_not_supported.h"
@@ -17,6 +17,8 @@
namespace isaac
{
namespace templates
{
base::parameters_type::parameters_type(unsigned int _simd_width, int_t _local_size_1, int_t _local_size_2, int_t _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), num_kernels(_num_kernels)
{ }
@@ -102,12 +104,12 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_row>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type==OPERATOR_MATRIX_COLUMN_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&array_expression, root_idx, &mapping_)));
else if (detail::is_scalar_reduction(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&array_expression, root_idx, &mapping_)));
else if (detail::is_vector_reduction(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&array_expression, root_idx, &mapping_)));
else if (detail::is_scalar_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_dot>(&array_expression, root_idx, &mapping_)));
else if (detail::is_vector_dot(root_node))
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemv>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_GEMM_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_gemm>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&array_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
@@ -198,7 +200,7 @@ void base::set_arguments_functor::operator()(isaac::array_expression const & arr
set_arguments(root_node.rhs);
}
void base::compute_reduction(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
void base::compute_dot(kernel_generation_stream & os, std::string acc, std::string cur, op_element const & op)
{
if (detail::is_elementwise_function(op))
os << acc << "=" << evaluate(op.type) << "(" << acc << "," << cur << ");" << std::endl;
@@ -206,7 +208,7 @@ void base::compute_reduction(kernel_generation_stream & os, std::string acc, std
os << acc << "= (" << acc << ")" << evaluate(op.type) << "(" << cur << ");" << std::endl;
}
void base::compute_index_reduction(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
void base::compute_index_dot(kernel_generation_stream & os, std::string acc, std::string cur, std::string const & acc_value, std::string const & cur_value, op_element const & op)
{
// os << acc << " = " << cur_value << ">" << acc_value << "?" << cur << ":" << acc << ";" << std::endl;
os << acc << "= select(" << acc << "," << cur << "," << cur_value << ">" << acc_value << ");" << std::endl;
@@ -259,7 +261,7 @@ std::string base::neutral_element(op_element const & op, driver::backend_type ba
case OPERATOR_ELEMENT_MIN_TYPE : return INF;
case OPERATOR_ELEMENT_ARGMIN_TYPE : return INF;
default: throw operation_not_supported_exception("Unsupported reduction operator : no neutral element known");
default: throw operation_not_supported_exception("Unsupported dot operator : no neutral element known");
}
}
@@ -399,14 +401,14 @@ std::pair<int_t, int_t> base::matrix_size(array_expression::node const & node)
return std::make_pair(node.lhs.array->shape()[0],node.lhs.array->shape()[1]);
}
bool base::is_reduction(array_expression::node const & node)
bool base::is_dot(array_expression::node const & node)
{
return node.op.type_family==OPERATOR_VECTOR_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_REDUCTION_TYPE_FAMILY;
return node.op.type_family==OPERATOR_VECTOR_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_COLUMNS_DOT_TYPE_FAMILY
|| node.op.type_family==OPERATOR_ROWS_DOT_TYPE_FAMILY;
}
bool base::is_index_reduction(op_element const & op)
bool base::is_index_dot(op_element const & op)
{
return op.type==OPERATOR_ELEMENT_ARGFMAX_TYPE
|| op.type==OPERATOR_ELEMENT_ARGMAX_TYPE
@@ -566,10 +568,11 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
return is_invalid_impl(device, expressions);
}
template class base_impl<vaxpy, vaxpy_parameters>;
template class base_impl<reduction, reduction_parameters>;
template class base_impl<maxpy, maxpy_parameters>;
template class base_impl<mreduction, mreduction_parameters>;
template class base_impl<mproduct, mproduct_parameters>;
template class base_impl<axpy, axpy_parameters>;
template class base_impl<dot, dot_parameters>;
template class base_impl<ger, ger_parameters>;
template class base_impl<gemv, gemv_parameters>;
template class base_impl<gemm, gemm_parameters>;
}
}

View File

@@ -1,5 +1,5 @@
#include <iostream>
#include "isaac/backend/templates/reduction.h"
#include "isaac/backend/templates/dot.h"
#include <CL/cl.hpp>
#include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp"
@@ -7,13 +7,14 @@
#include "isaac/backend/keywords.h"
namespace isaac
{
reduction_parameters::reduction_parameters(unsigned int _simd_width,
namespace templates
{
dot_parameters::dot_parameters(unsigned int _simd_width,
unsigned int _group_size, unsigned int _num_groups,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _group_size, 1, 2), num_groups(_num_groups), fetching_policy(_fetching_policy)
{ }
unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
unsigned int dot::lmem_usage(expressions_tuple const & expressions) const
{
unsigned int res = 0;
for(const auto & elem : expressions.data())
@@ -24,14 +25,14 @@ unsigned int reduction::lmem_usage(expressions_tuple const & expressions) const
return res;
}
int reduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int dot::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.fetching_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_reduction*> exprs,
inline void dot::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<mapped_scalar_dot*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type backend) const
{
stream << "#pragma unroll" << std::endl;
@@ -44,26 +45,26 @@ inline void reduction::reduce_1d_local_memory(kernel_generation_stream & stream,
stream.inc_tab();
for (auto & expr : exprs)
if (expr->is_index_reduction())
compute_index_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
if (expr->is_index_dot())
compute_index_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]")
, expr->process(buf_value_str+"[lid]"), expr->process(buf_value_str+"[lid+stride]"),
expr->root_op());
else
compute_reduction(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
compute_dot(stream, expr->process(buf_str+"[lid]"), expr->process(buf_str+"[lid+stride]"), expr->root_op());
stream.dec_tab();
stream << "}" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
}
std::string reduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string dot::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
kernel_generation_stream stream;
std::vector<mapped_scalar_reduction*> exprs;
std::vector<mapped_scalar_dot*> exprs;
for (const auto & mapping : mappings)
for (mapping_type::const_iterator iit = mapping.begin(); iit != mapping.end(); ++iit)
if (mapped_scalar_reduction * p = dynamic_cast<mapped_scalar_reduction*>(iit->second.get()))
if (mapped_scalar_dot * p = dynamic_cast<mapped_scalar_dot*>(iit->second.get()))
exprs.push_back(p);
std::size_t N = exprs.size();
driver::backend_type backend = device.backend();
@@ -73,7 +74,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k)
{
std::string numeric_type = numeric_type_to_string(lhs_most(exprs[k]->array_expression().tree(), exprs[k]->array_expression().root()).lhs.dtype);
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
{
arguments += exprs[k]->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += exprs[k]->process(Global(backend).get() + " " + tools::to_string(numeric_type) + "* #name_temp_value, ");
@@ -112,7 +113,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
{
stream << exprs[k]->process(Local(backend).get() + " #scalartype #name_buf_value[" + tools::to_string(p_.local_size_0) + "];") << std::endl;
stream << exprs[k]->process("#scalartype #name_acc_value = " + neutral_element(exprs[k]->root_op(), backend, "#scalartype") + ";") << std::endl;
@@ -156,11 +157,11 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
accessors["matrix_diag"] = str[a];
accessors["array0"] = "#namereg";
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, accessors);
if (elem->is_index_reduction())
compute_index_reduction(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
if (elem->is_index_dot())
compute_index_dot(stream, elem->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
+ tools::to_string(a), elem->process("#name_acc_value"), value,elem->root_op());
else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
}
}
});
@@ -168,7 +169,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
//Fills local memory
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
}
@@ -182,7 +183,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream.inc_tab();
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_temp_value[gpid] = #name_buf_value[0];") << std::endl;
stream << exprs[k]->process("#name_temp[gpid] = #name_buf[0];") << std::endl;
}
@@ -205,9 +206,9 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "unsigned int lid = " <<LocalIdx0(backend) << ";" << std::endl;
stream << "unsigned int lsize = " <<LocalSize0(backend) << ";" << std::endl;
for (mapped_scalar_reduction* e: exprs)
for (mapped_scalar_dot* e: exprs)
{
if (e->is_index_reduction())
if (e->is_index_dot())
{
stream << e->process(Local(backend).get() + " unsigned int #name_buf[" + tools::to_string(p_.local_size_0) + "];");
stream << e->process("unsigned int #name_acc = 0;") << std::endl;
@@ -224,18 +225,18 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "for(unsigned int i = lid; i < " << p_.num_groups << "; i += lsize)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
for (mapped_scalar_reduction* e: exprs)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
for (mapped_scalar_dot* e: exprs)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->process("#name_acc_value"),e->process("#name_temp_value[i]"),e->root_op());
else
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[i]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
for (unsigned int k = 0; k < N; ++k)
{
if (exprs[k]->is_index_reduction())
if (exprs[k]->is_index_dot())
stream << exprs[k]->process("#name_buf_value[lid] = #name_acc_value;") << std::endl;
stream << exprs[k]->process("#name_buf[lid] = #name_acc;") << std::endl;
}
@@ -248,7 +249,7 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
stream << "{" << std::endl;
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["scalar_reduction"] = "#name_buf[0]";
accessors["scalar_dot"] = "#name_buf[0]";
accessors["array0"] = "#pointer[#start]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
stream.dec_tab();
@@ -260,23 +261,23 @@ std::string reduction::generate_impl(const char * suffix, expressions_tuple cons
return stream.str();
}
reduction::reduction(reduction::parameters_type const & parameters,
binding_policy_t binding) : base_impl<reduction, reduction_parameters>(parameters, binding)
dot::dot(dot::parameters_type const & parameters,
binding_policy_t binding) : base_impl<dot, dot_parameters>(parameters, binding)
{ }
reduction::reduction(unsigned int simd, unsigned int ls, unsigned int ng,
dot::dot(unsigned int simd, unsigned int ls, unsigned int ng,
fetching_policy_type fetch, binding_policy_t bind):
base_impl<reduction, reduction_parameters>(reduction_parameters(simd,ls,ng,fetch), bind)
base_impl<dot, dot_parameters>(dot_parameters(simd,ls,ng,fetch), bind)
{}
std::vector<int_t> reduction::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *(expressions.data().front()), false);
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), reductions_idx[0]));
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *(expressions.data().front()), false);
int_t N = vector_size(lhs_most(expressions.data().front()->tree(), dots_idx[0]));
return tools::make_vector<int_t>() << N;
}
void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
@@ -290,12 +291,12 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
return;
}
std::vector<array_expression::node const *> reductions;
std::vector<array_expression::node const *> dots;
for (const auto & elem : expressions.data())
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *elem, false);
for (auto & reductions_idx_itt : reductions_idx)
reductions.push_back(&(elem)->tree()[reductions_idx_itt]);
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *elem, false);
for (auto & dots_idx_itt : dots_idx)
dots.push_back(&(elem)->tree()[dots_idx_itt]);
}
//Kernel
@@ -321,9 +322,9 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
//Temporary buffers
unsigned int i = 0;
unsigned int j = 0;
for (std::vector<array_expression::node const *>::const_iterator it = reductions.begin(); it != reductions.end(); ++it)
for (std::vector<array_expression::node const *>::const_iterator it = dots.begin(); it != dots.end(); ++it)
{
if (is_index_reduction((*it)->op))
if (is_index_dot((*it)->op))
{
if (tmpidx_.size() <= j)
tmpidx_.push_back(driver::Buffer(context, p_.num_groups*4));
@@ -343,3 +344,4 @@ void reduction::enqueue(driver::CommandQueue & queue, driver::Program & program,
}
}
}

View File

@@ -1,5 +1,5 @@
#include "isaac/array.h"
#include "isaac/backend/templates/mproduct.h"
#include "isaac/backend/templates/gemm.h"
#include "isaac/backend/keywords.h"
#include "isaac/model/model.h"
#include "isaac/symbolic/preset.h"
@@ -10,8 +10,10 @@
namespace isaac
{
namespace templates
{
mproduct_parameters::mproduct_parameters(unsigned int simd_width
gemm_parameters::gemm_parameters(unsigned int simd_width
, int_t local_size_0, int_t KL, int_t local_size_1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
@@ -21,7 +23,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
mL(ms*local_size_0), nL(ns*local_size_1){}
unsigned int mproduct::lmem_usage(expressions_tuple const & expressions) const
unsigned int gemm::lmem_usage(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -32,7 +34,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
unsigned int mproduct::registers_usage(expressions_tuple const & expressions) const
unsigned int gemm::registers_usage(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = (*expressions.data().front());
numeric_type numeric_t = lhs_most(array_expression.tree(), array_expression.root()).lhs.dtype;
@@ -41,7 +43,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
int mproduct::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
{
std::vector<int_t> MNK = input_sizes(expressions);
int_t M = MNK[0]; int_t N = MNK[1];
@@ -95,7 +97,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return TEMPLATE_VALID;
}
std::string mproduct::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
std::string gemm::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const
{
using std::string;
using tools::to_string;
@@ -437,7 +439,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
#undef VST0RE
}
void mproduct::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K,
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
array const & A, array const & B, array const & C,
value_scalar const & alpha, value_scalar const & beta,
driver::Program & program, const char * suffix, execution_options_type const & options)
@@ -516,7 +518,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
}
}
array mproduct::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{
slice s0(s0_0, s0_1);
slice s1(s1_0, s1_1);
@@ -525,7 +527,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return array(M, s0, s1);
}
std::vector<int_t> mproduct::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
std::vector<int_t> gemm::infos(expressions_tuple const & expressions, symbolic::preset::gemm::args& arguments) const
{
isaac::array_expression & array_expression = (*expressions.data().front());
array_expression::container_type & array = array_expression.tree();
@@ -537,26 +539,26 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
return {M, N, K};
}
mproduct::mproduct(mproduct_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<mproduct, mproduct_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
{
if(A_trans_=='N' && B_trans_=='N') type_ = MATRIX_PRODUCT_NN_TYPE;
else if(A_trans_=='T' && B_trans_=='N') type_ = MATRIX_PRODUCT_TN_TYPE;
else if(A_trans_=='N' && B_trans_=='T') type_ = MATRIX_PRODUCT_NT_TYPE;
else if(A_trans_=='T' && B_trans_=='T') type_ = MATRIX_PRODUCT_TT_TYPE;
if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN_TYPE;
else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN_TYPE;
else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT_TYPE;
else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT_TYPE;
else throw;
}
std::vector<int_t> mproduct::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> gemm::input_sizes(expressions_tuple const & expressions) const
{
symbolic::preset::gemm::args dummy;
return infos(expressions, dummy);
}
void mproduct::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
{
using namespace tools;
mproduct & fallback = (mproduct&)fallback_base;
gemm & fallback = (gemm&)fallback_base;
expressions_tuple const & expressions = ctr.x();
@@ -579,8 +581,6 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
int_t ldstrideA = pA->stride()[0];
int_t ldstrideB = pB->stride()[0];
int_t ldstrideC = pC->stride()[0];
int_t ldstartA = pA->start()[0];
int_t ldstartB = pB->start()[0];
numeric_type dtype = args.C->dtype;
@@ -613,40 +613,41 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
}
//
mproduct_nn::mproduct_nn(unsigned int simd
gemm_nn::gemm_nn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'N')
{ }
//
mproduct_tn::mproduct_tn(unsigned int simd
gemm_tn::gemm_tn(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'N')
{ }
//
mproduct_nt::mproduct_nt(unsigned int simd
gemm_nt::gemm_nt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'N', 'T')
{ }
//
mproduct_tt::mproduct_tt(unsigned int simd
gemm_tt::gemm_tt(unsigned int simd
, int_t ls0, int_t KL, int_t ls1, int_t D
, int_t ms, int_t ks, int_t ns
, fetching_policy_type Afetch , fetching_policy_type Bfetch
, int_t lfetch0, int_t lfetch1, bool check_bound) :
mproduct(mproduct_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lfetch0, lfetch1), check_bound, 'T', 'T')
{ }
}
}

View File

@@ -1,48 +1,50 @@
#include <iostream>
#include "isaac/backend/stream.h"
#include "isaac/backend/keywords.h"
#include "isaac/backend/templates/mreduction.h"
#include "isaac/backend/templates/gemv.h"
#include "isaac/tools/to_string.hpp"
#include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp"
namespace isaac
{
namespace templates
{
mreduction_parameters::mreduction_parameters(unsigned int _simd_width,
gemv_parameters::gemv_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1, fetching_policy_type _fetch_policy): base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetch_policy(_fetch_policy) { }
int mreduction::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int gemv::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if(reduction_type_==REDUCE_ROWS && p_.simd_width>1)
if(dot_type_==REDUCE_ROWS && p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
if (p_.fetch_policy==FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
return TEMPLATE_VALID;
}
unsigned int mreduction::lmem_usage() const
unsigned int gemv::lmem_usage() const
{
return (p_.local_size_0+1)*p_.local_size_1;
}
std::string mreduction::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string gemv::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
using tools::to_string;
std::vector<mapped_mreduction*> reductions;
std::vector<mapped_gemv*> dots;
expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::const_iterator mit;
for (mit = mappings.begin(), sit = expressions.data().begin(); mit != mappings.end(); ++mit, ++sit)
{
array_expression const & first_expression = *expressions.data().front();
std::vector<size_t> idx = filter_nodes(&is_reduction, first_expression, false);
std::vector<size_t> idx = filter_nodes(&is_dot, first_expression, false);
for (auto & elem : idx)
reductions.push_back((mapped_mreduction*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
dots.push_back((mapped_gemv*)(mit->at(mapping_key(elem, PARENT_NODE_TYPE)).get()));
}
kernel_generation_stream stream;
@@ -54,10 +56,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
strcat(name[1], suffix);
std::string arguments = _size_t + " M, " + _size_t + " N, " ;
for (const auto & e : reductions)
for (const auto & e : dots)
{
std::string numeric_type = numeric_type_to_string(lhs_most(e->array_expression().tree(), e->array_expression().root()).lhs.dtype);
if (e->is_index_reduction())
if (e->is_index_dot())
{
arguments += e->process(Global(backend).get() + " unsigned int* #name_temp, ");
arguments += e->process(Global(backend).get() + " " + to_string(numeric_type) + "* #name_temp_value,");
@@ -87,7 +89,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
unsigned int local_size_0_ld = p_.local_size_0;
std::string local_size_0_ld_str = to_string(local_size_0_ld);
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << "" << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -104,7 +106,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab();
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl;
@@ -116,10 +118,10 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
std::string data_type = append_width("#scalartype",simd_width);
for (const auto & e : reductions)
for (const auto & e : dots)
{
std::map<std::string, std::string> accessors;
if(reduction_type_==REDUCE_COLUMNS)
if(dot_type_==REDUCE_COLUMNS)
{
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
@@ -141,20 +143,20 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
str[a] = access_vector_type("#namereg",a);
for (auto & elem : reductions)
for (auto & elem : dots)
for (unsigned int a = 0; a < simd_width; ++a)
{
std::string value = elem->evaluate_recursive(LHS_NODE_TYPE, {{"array2", str[a]}, {"repeat", str[a]}, {"array0", "#namereg"}});
if (elem->is_index_reduction())
compute_index_reduction(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
if (elem->is_index_dot())
compute_index_dot(stream, elem->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), elem->process("#name_acc_value"), value, elem->root_op());
else
compute_reduction(stream, elem->process("#name_acc"), value,elem->root_op());
compute_dot(stream, elem->process("#name_acc"), value,elem->root_op());
}
});
stream.dec_tab();
stream << "}" << std::endl;
for (auto & expr : reductions)
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
@@ -167,13 +169,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : reductions)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op());
else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -188,15 +190,15 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
if(p_.num_groups_0==1)
{
std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
}
else
{
for (mapped_reduction const * e : reductions)
for (mapped_dot const * e : dots)
{
if (e->is_index_reduction())
if (e->is_index_dot())
stream << e->process("#name_temp_value[r + M*gpid0] = #name_buf_value[lid1*" + local_size_0_ld_str + "];") << std::endl;
stream << e->process("#name_temp[r + M*gpid0] = #name_buf[lid1*" + local_size_0_ld_str + "];") << std::endl;
}
@@ -230,7 +232,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
{"array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "}}, expressions, mappings);
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
stream << _size_t << " lid0 = " << LocalIdx0(backend) << ";" << std::endl;
@@ -246,7 +248,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " r = gid1; r < upper_bound_1; r += gsize1){" << std::endl;
stream.inc_tab();
for (const auto & e : reductions)
for (const auto & e : dots)
stream << e->process("#scalartype #name_acc = " + neutral_element((e)->root_op(), backend, "#scalartype") + ";") << std::endl;
stream << "if (r < M)" << std::endl;
@@ -256,8 +258,8 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "for(" << _size_t << " c = lid0; c < " << p_.num_groups_0 << "; c += lsize0){" << std::endl;
stream.inc_tab();
for (mapped_reduction* e: reductions)
compute_reduction(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
for (mapped_dot* e: dots)
compute_dot(stream, e->process("#name_acc"), e->process("#name_temp[r + M*c]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -266,7 +268,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.dec_tab();
stream << "}" << std::endl;
for (auto & expr : reductions)
for (auto & expr : dots)
stream << expr->process("#name_buf[lid1*" + local_size_0_ld_str + "+ lid0] = #name_acc;") << std::endl;
stream << "#pragma unroll" << std::endl;
@@ -279,13 +281,13 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream << "{" << std::endl;
stream.inc_tab();
for (auto & e : reductions)
if (e->is_index_reduction())
compute_index_reduction(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
for (auto & e : dots)
if (e->is_index_dot())
compute_index_dot(stream, e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf_value[lid1*" + local_size_0_ld_str + " + lid0 + stride]")
, e->root_op());
else
compute_reduction(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
compute_dot(stream,e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0]"), e->process("#name_buf[lid1*" + local_size_0_ld_str + " + lid0 + stride]"), e->root_op());
stream.dec_tab();
stream << "}" << std::endl;
@@ -299,7 +301,7 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["mreduction"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["gemv"] = "#name_buf[lid1*" + local_size_0_ld_str + "]";
accessors["array1"] = "#pointer[r*#stride]";
evaluate(stream, PARENT_NODE_TYPE, accessors, expressions, mappings);
@@ -317,38 +319,38 @@ std::string mreduction::generate_impl(const char * suffix, expressions_tuple con
return stream.str();
}
mreduction::mreduction(mreduction::parameters_type const & parameters,
mreduction::reduction_type rtype,
gemv::gemv(gemv::parameters_type const & parameters,
gemv::dot_type rtype,
binding_policy_t binding_policy) :
base_impl<mreduction, mreduction_parameters>(parameters, binding_policy),
reduction_type_(rtype){ }
base_impl<gemv, gemv_parameters>(parameters, binding_policy),
dot_type_(rtype){ }
std::vector<int_t> mreduction::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) const
{
array_expression const & first_expression = *expressions.data().front();
std::vector<std::size_t> idx = filter_nodes(&is_reduction, first_expression, false);
std::vector<std::size_t> idx = filter_nodes(&is_dot, first_expression, false);
std::pair<int_t, int_t> MN = matrix_size(lhs_most(first_expression.tree(), idx[0]));
if(reduction_type_==REDUCE_COLUMNS)
if(dot_type_==REDUCE_COLUMNS)
std::swap(MN.first,MN.second);
return tools::make_vector<int_t>() << MN.first << MN.second;
}
void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
driver::Context const & context = expressions.context();
std::vector<int_t> MN = input_sizes(expressions);
std::vector<array_expression::node const *> reductions;
std::vector<array_expression::node const *> dots;
for (const auto & e : expressions.data())
{
std::vector<size_t> reductions_idx = filter_nodes(&is_reduction, *e, false);
for (auto & r : reductions_idx)
reductions.push_back(&(e)->tree()[r]);
std::vector<size_t> dots_idx = filter_nodes(&is_dot, *e, false);
for (auto & r : dots_idx)
dots.push_back(&(e)->tree()[r]);
}
//Fallback
if(reduction_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
if(dot_type_==REDUCE_COLUMNS && p_.simd_width>1 && requires_fallback(expressions))
{
fallback.enqueue(queue, program, "fallback", fallback, controller);
return;
@@ -381,9 +383,9 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
//Temporary buffers
unsigned int i = 0;
unsigned int j = 0;
for (auto const & r : reductions)
for (auto const & r : dots)
{
if (is_index_reduction(r->op))
if (is_index_dot(r->op))
{
if (tmpidx.size() <= j)
tmpidx.push_back(driver::Buffer(context, p_.num_groups_0*M*4));
@@ -405,24 +407,25 @@ void mreduction::enqueue(driver::CommandQueue & queue, driver::Program & program
controller.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]);
}
mreduction_rows::mreduction_rows(mreduction_parameters const & parameters,
gemv_n::gemv_n(gemv_parameters const & parameters,
binding_policy_t binding_policy):
mreduction(parameters, REDUCE_ROWS, binding_policy){}
gemv(parameters, REDUCE_ROWS, binding_policy){}
mreduction_rows::mreduction_rows(unsigned int simd, unsigned int ls1, unsigned int ls2,
gemv_n::gemv_n(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS, bind)
{}
mreduction_cols::mreduction_cols(mreduction::parameters_type const & parameters,
gemv_t::gemv_t(gemv::parameters_type const & parameters,
binding_policy_t binding_policy):
mreduction(parameters, REDUCE_COLUMNS, binding_policy){}
gemv(parameters, REDUCE_COLUMNS, binding_policy){}
mreduction_cols::mreduction_cols(unsigned int simd, unsigned int ls1, unsigned int ls2,
gemv_t::gemv_t(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind):
mreduction(mreduction_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
gemv(gemv_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS, bind)
{}
}
}

View File

@@ -1,4 +1,4 @@
#include "isaac/backend/templates/maxpy.h"
#include "isaac/backend/templates/ger.h"
#include "isaac/tools/make_map.hpp"
#include "isaac/tools/make_vector.hpp"
#include "isaac/symbolic/io.h"
@@ -7,15 +7,17 @@
namespace isaac
{
namespace templates
{
maxpy_parameters::maxpy_parameters(unsigned int _simd_width,
ger_parameters::ger_parameters(unsigned int _simd_width,
unsigned int _local_size_0, unsigned int _local_size_1,
unsigned int _num_groups_0, unsigned int _num_groups_1,
fetching_policy_type _fetching_policy) : base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1), num_groups_0(_num_groups_0), num_groups_1(_num_groups_1), fetching_policy(_fetching_policy){ }
int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int ger::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if (p_.simd_width>1)
return TEMPLATE_INVALID_SIMD_WIDTH;
@@ -24,7 +26,7 @@ int maxpy::is_invalid_impl(driver::Device const &, expressions_tuple const &) co
return TEMPLATE_VALID;
}
std::string maxpy::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
std::string ger::generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mappings) const
{
kernel_generation_stream stream;
std::string _size_t = size_type(device);
@@ -95,23 +97,23 @@ std::string maxpy::generate_impl(const char * suffix, expressions_tuple const &
return stream.str();
}
maxpy::maxpy(parameters_type const & parameters, binding_policy_t binding_policy) :
base_impl<maxpy, maxpy_parameters>(parameters, binding_policy){ }
ger::ger(parameters_type const & parameters, binding_policy_t binding_policy) :
base_impl<ger, ger_parameters>(parameters, binding_policy){ }
maxpy::maxpy(unsigned int simd, unsigned int ls1, unsigned int ls2,
ger::ger(unsigned int simd, unsigned int ls1, unsigned int ls2,
unsigned int ng1, unsigned int ng2, fetching_policy_type fetch,
binding_policy_t bind):
base_impl<maxpy, maxpy_parameters>(maxpy_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
base_impl<ger, ger_parameters>(ger_parameters(simd, ls1, ls2, ng1, ng2, fetch), bind)
{}
std::vector<int_t> maxpy::input_sizes(expressions_tuple const & expressions) const
std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
{
isaac::array_expression const & array_expression = *(expressions.data().front());
std::pair<int_t, int_t> size = matrix_size(lhs_most(array_expression.tree(), array_expression.root()));
return tools::make_vector<int_t>() << size.first << size.second;
}
void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
{
expressions_tuple const & expressions = controller.x();
char name[32] = {"axpy"};
@@ -129,3 +131,4 @@ void maxpy::enqueue(driver::CommandQueue & queue, driver::Program & program, con
}
}
}