Code quality: removed tools::shared_ptr<>

This commit is contained in:
Philippe Tillet
2015-07-28 15:26:10 -07:00
parent 0434ac551c
commit 9c15debf8b
10 changed files with 70 additions and 231 deletions

View File

@@ -20,7 +20,7 @@ enum leaf_t
class mapped_object;
typedef std::pair<int_t, leaf_t> mapping_key;
typedef std::map<mapping_key, tools::shared_ptr<mapped_object> > mapping_type;
typedef std::map<mapping_key, std::shared_ptr<mapped_object> > mapping_type;
/** @brief Mapped Object
*

View File

@@ -76,15 +76,15 @@ protected:
/** @brief Accessor for the numeric type */
numeric_type get_numeric_type(isaac::array_expression const * array_expression, int_t root_idx) const;
/** @brief Creates a binary leaf */
template<class T> tools::shared_ptr<mapped_object> binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const;
template<class T> std::shared_ptr<mapped_object> binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const;
/** @brief Creates a value scalar mapping */
tools::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
std::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
/** @brief Creates a vector mapping */
tools::shared_ptr<mapped_object> create(array const *) const;
std::shared_ptr<mapped_object> create(array const *) const;
/** @brief Creates a tuple mapping */
tools::shared_ptr<mapped_object> create(repeat_infos const &) const;
std::shared_ptr<mapped_object> create(repeat_infos const &) const;
/** @brief Creates a mapping */
tools::shared_ptr<mapped_object> create(lhs_rhs_element const &) const;
std::shared_ptr<mapped_object> create(lhs_rhs_element const &) const;
public:
map_functor(symbolic_binder & binder, mapping_type & mapping, const driver::Device &device);
/** @brief Functor for traversing the tree */
@@ -165,7 +165,7 @@ protected:
static bool is_index_dot(op_element const & op);
static std::string access_vector_type(std::string const & v, int i);
tools::shared_ptr<symbolic_binder> make_binder();
std::shared_ptr<symbolic_binder> make_binder();
static std::string vstore(unsigned int simd_width, std::string const & dtype, std::string const & value, std::string const & offset, std::string const & ptr, driver::backend_type backend);
static std::string vload(unsigned int simd_width, std::string const & dtype, std::string const & offset, std::string const & ptr, driver::backend_type backend);
static std::string append_width(std::string const & str, unsigned int width);
@@ -182,7 +182,7 @@ public:
std::string generate(const char * suffix, expressions_tuple const & expressions, driver::Device const & device);
virtual int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const = 0;
virtual void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & expressions) = 0;
virtual tools::shared_ptr<base> clone() const = 0;
virtual std::shared_ptr<base> clone() const = 0;
private:
binding_policy_t binding_policy_;
};
@@ -198,7 +198,7 @@ public:
base_impl(parameters_type const & parameters, binding_policy_t binding_policy);
int_t local_size_0() const;
int_t local_size_1() const;
tools::shared_ptr<base> clone() const;
std::shared_ptr<base> clone() const;
/** @brief returns whether or not the profile has undefined behavior on particular device */
int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const;
protected:

View File

@@ -14,7 +14,7 @@ namespace isaac
class model
{
typedef tools::shared_ptr<templates::base> template_pointer;
typedef std::shared_ptr<templates::base> template_pointer;
typedef std::vector< template_pointer > templates_container;
private:
@@ -23,7 +23,7 @@ namespace isaac
driver::Program& init(controller<expressions_tuple> const &);
public:
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
model(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
void execute(controller<expressions_tuple> const &);
@@ -35,18 +35,18 @@ namespace isaac
private:
templates_container templates_;
template_pointer fallback_;
tools::shared_ptr<predictors::random_forest> predictor_;
std::shared_ptr<predictors::random_forest> predictor_;
std::map<std::vector<int_t>, int> hardcoded_;
std::map<driver::Context, std::map<std::string, std::shared_ptr<driver::Program> > > programs_;
driver::CommandQueue queue_;
};
typedef std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<model> > model_map_t;
typedef std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<model> > model_map_t;
model_map_t init_models(driver::CommandQueue const & queue);
model_map_t& models(driver::CommandQueue & queue);
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks;
extern std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks;
extern std::map<driver::CommandQueue, model_map_t> models_;
}

View File

@@ -14,7 +14,7 @@
#include "isaac/types.h"
#include "isaac/value_scalar.h"
#include "isaac/tools/shared_ptr.hpp"
#include <memory>
#include <iostream>
namespace isaac
@@ -290,9 +290,9 @@ controller<TYPE> control(TYPE const & x, execution_options_type const& execution
class expressions_tuple
{
private:
tools::shared_ptr<array_expression> create(array_expression const & s);
std::shared_ptr<array_expression> create(array_expression const & s);
public:
typedef std::list<tools::shared_ptr<array_expression> > data_type;
typedef std::list<std::shared_ptr<array_expression> > data_type;
enum order_type { SEQUENTIAL, INDEPENDENT };
expressions_tuple(array_expression const & s0);

View File

@@ -1,162 +0,0 @@
#ifndef ISAAC_TOOLS_SHARED_PTR_HPP
#define ISAAC_TOOLS_SHARED_PTR_HPP
/* =========================================================================
Copyright (c) 2010-2012, Institute for Microelectronics,
Institute for Analysis and Scientific Computing,
TU Wien.
Portions of this software are copyright by UChicago Argonne, LLC.
-----------------
ViennaCL - The Vienna Computing Library
-----------------
Project Head: Karl Rupp rupp@iue.tuwien.ac.at
(A list of authors and contributors can be found in the PDF manual)
License: MIT (X11), see file LICENSE in the base directory
============================================================================= */
/** @file tools/shared_ptr.hpp
@brief Implementation of a shared pointer class (cf. tools::shared_ptr, boost::shared_ptr). Will be used until C++11 is widely available.
Contributed by Philippe Tillet.
*/
#include <cstdlib>
#include <algorithm>
namespace isaac
{
namespace tools
{
namespace detail
{
/** @brief Reference counting class for the shared_ptr implementation */
class count
{
public:
count(unsigned int val) : val_(val){ }
void dec(){ --val_; }
void inc(){ ++val_; }
bool is_null(){ return val_ == 0; }
unsigned int val(){ return val_; }
private:
unsigned int val_;
};
/** @brief Interface for the reference counter inside the shared_ptr */
struct aux
{
detail::count count;
aux() :count(1) {}
virtual void destroy()=0;
virtual ~aux() {}
};
/** @brief Implementation helper for the reference counting mechanism inside shared_ptr. */
template<class U, class Deleter>
struct auximpl: public detail::aux
{
U* p;
Deleter d;
auximpl(U* pu, Deleter x) :p(pu), d(x) {}
virtual void destroy() { d(p); }
};
/** @brief Default deleter class for a pointer. The default is to just call 'delete' on the pointer. Provide your own implementations for 'delete[]' and 'free'. */
template<class U>
struct default_deleter
{
void operator()(U* p) const { delete p; }
};
}
class shared_ptr_base
{
protected:
detail::aux* pa;
public:
unsigned int count() { return pa->count.val(); }
};
/** @brief A shared pointer class similar to boost::shared_ptr. Reimplemented in order to avoid a Boost-dependency. Will be replaced by tools::shared_ptr as soon as C++11 is widely available. */
template<class T>
class shared_ptr : public shared_ptr_base
{
template<class U>
friend class shared_ptr;
detail::aux* pa;
T* pt;
public:
shared_ptr() :pa(NULL), pt(NULL) {}
template<class U, class Deleter>
shared_ptr(U* pu, Deleter d) : pa(new detail::auximpl<U, Deleter>(pu, d)), pt(pu) {}
template<class U>
explicit shared_ptr(U* pu) : pa(new detail::auximpl<U, detail::default_deleter<U> >(pu, detail::default_deleter<U>())), pt(pu) {}
template<class U>
shared_ptr(const shared_ptr<U>& s) :pa(s.pa), pt(s.pt) { inc(); }
shared_ptr(const shared_ptr& s) :pa(s.pa), pt(s.pt) { inc(); }
~shared_ptr() { dec(); }
T* get() const { return pt; }
T* operator->() const { return pt; }
T& operator*() const { return *pt; }
void reset() { shared_ptr<T>().swap(*this); }
void reset(T * ptr) { shared_ptr<T>(ptr).swap(*this); }
void swap(shared_ptr<T> & other)
{
std::swap(pt,other.pt);
std::swap(pa, other.pa);
}
shared_ptr& operator=(const shared_ptr& s)
{
if (this!=&s)
{
dec();
pa = s.pa;
pt = s.pt;
inc();
}
return *this;
}
void inc()
{
if (pa) pa->count.inc();
}
void dec()
{
if (pa)
{
pa->count.dec();
if (pa->count.is_null())
{
pa->destroy();
delete pa;
pa = NULL;
}
}
}
};
}
}
#endif

View File

@@ -33,44 +33,44 @@ numeric_type base::map_functor::get_numeric_type(isaac::array_expression const *
/** @brief Binary leaf */
template<class T>
tools::shared_ptr<mapped_object> base::map_functor::binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const
std::shared_ptr<mapped_object> base::map_functor::binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const
{
return tools::shared_ptr<mapped_object>(new T(numeric_type_to_string(array_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, array_expression, root_idx)));
return std::shared_ptr<mapped_object>(new T(numeric_type_to_string(array_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, array_expression, root_idx)));
}
/** @brief Scalar mapping */
tools::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, values_holder) const
std::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, values_holder) const
{
std::string strdtype = numeric_type_to_string(dtype);
return tools::shared_ptr<mapped_object>(new mapped_host_scalar(strdtype, binder_.get()));
return std::shared_ptr<mapped_object>(new mapped_host_scalar(strdtype, binder_.get()));
}
/** @brief Vector mapping */
tools::shared_ptr<mapped_object> base::map_functor::create(array const * a) const
std::shared_ptr<mapped_object> base::map_functor::create(array const * a) const
{
std::string dtype = numeric_type_to_string(a->dtype());
unsigned int id = binder_.get(a->data());
//Scalar
if(a->shape()[0]==1 && a->shape()[1]==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
//Column vector
else if(a->shape()[0]>1 && a->shape()[1]==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector
else if(a->shape()[0]==1 && a->shape()[1]>1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
//Matrix
else
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
}
tools::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
std::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
{
//TODO: Make it less specific!
return tools::shared_ptr<mapped_object>(new mapped_tuple(size_type(device_),binder_.get(),4));
return std::shared_ptr<mapped_object>(new mapped_tuple(size_type(device_),binder_.get(),4));
}
tools::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const & lhs_rhs) const
std::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const & lhs_rhs) const
{
switch(lhs_rhs.type_family)
{
@@ -115,7 +115,7 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&array_expression, root_idx, &mapping_)));
else if (detail::is_cast(root_node.op))
mapping_.insert(mapping_type::value_type(key, tools::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
}
}
@@ -289,7 +289,7 @@ std::string base::generate_arguments(std::string const & data_type, driver::Devi
void base::set_arguments(expressions_tuple const & expressions, driver::Kernel & kernel, unsigned int & current_arg)
{
tools::shared_ptr<symbolic_binder> binder = make_binder();
std::shared_ptr<symbolic_binder> binder = make_binder();
for (const auto & elem : expressions.data())
traverse(*elem, (elem)->root(), set_arguments_functor(*binder, current_arg, kernel), true);
}
@@ -483,12 +483,12 @@ std::string base::append_width(std::string const & str, unsigned int width)
return str + tools::to_string(width);
}
tools::shared_ptr<symbolic_binder> base::make_binder()
std::shared_ptr<symbolic_binder> base::make_binder()
{
if (binding_policy_==BIND_TO_HANDLE)
return tools::shared_ptr<symbolic_binder>(new bind_to_handle());
return std::shared_ptr<symbolic_binder>(new bind_to_handle());
else
return tools::shared_ptr<symbolic_binder>(new bind_all_unique());
return std::shared_ptr<symbolic_binder>(new bind_all_unique());
}
@@ -514,7 +514,7 @@ std::string base::generate(const char * suffix, expressions_tuple const & expres
//Create mapping
std::vector<mapping_type> mappings(expressions.data().size());
tools::shared_ptr<symbolic_binder> binder = make_binder();
std::shared_ptr<symbolic_binder> binder = make_binder();
for (mit = mappings.begin(), sit = expressions.data().begin(); sit != expressions.data().end(); ++sit, ++mit)
traverse(**sit, (*sit)->root(), map_functor(*binder,*mit,device), true);
@@ -538,8 +538,8 @@ int_t base_impl<TType, PType>::local_size_1() const
{ return p_.local_size_1; }
template<class TType, class PType>
tools::shared_ptr<base> base_impl<TType, PType>::clone() const
{ return tools::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
std::shared_ptr<base> base_impl<TType, PType>::clone() const
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
template<class TType, class PType>
int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, driver::Device const & device) const

View File

@@ -85,7 +85,7 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
return *program;
}
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
{}
@@ -166,27 +166,27 @@ namespace detail
throw std::invalid_argument("Invalid datatype: " + name);
}
static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
static std::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
{
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="axpy")
return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
return std::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="dot")
return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
return std::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="ger")
return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_n")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_t")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemm_nn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_nt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
}
@@ -222,7 +222,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
numeric_type dtype = detail::get_dtype(elem);
// Get profiles
std::vector<tools::shared_ptr<templates::base> > templates;
std::vector<std::shared_ptr<templates::base> > templates;
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
@@ -231,10 +231,10 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
{
// Get predictor
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
}
else
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
}
}
}
@@ -242,9 +242,9 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
{
typedef tools::shared_ptr<templates::base> ptr_t;
typedef std::shared_ptr<templates::base> ptr_t;
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types)
@@ -271,7 +271,7 @@ model_map_t init_models(driver::CommandQueue & queue)
for(numeric_type dtype: dtypes)
for(expression_type etype: etypes)
res[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
res[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
if(const char * homepath = std::getenv("HOME"))
import(std::string(homepath) + "/.isaac/devices/device0.json", queue, res);
@@ -286,7 +286,7 @@ model_map_t& models(driver::CommandQueue & queue)
return it->second;
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<driver::CommandQueue, model_map_t> models_;
}

View File

@@ -177,29 +177,29 @@ namespace isaac
/*----Parse required temporaries-----*/
detail::parse(tree, rootidx, breakpoints, final_type);
std::vector<tools::shared_ptr<array> > temporaries_;
std::vector<std::shared_ptr<array> > temporaries_;
/*----Compute required temporaries----*/
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
{
tools::shared_ptr<model> const & pmodel = models[std::make_pair(it->first, dtype)];
std::shared_ptr<model> const & pmodel = models[std::make_pair(it->first, dtype)];
array_expression::node const & node = tree[it->second->node_index];
array_expression::node const & lmost = lhs_most(tree, node);
//Creates temporary
tools::shared_ptr<array> tmp;
std::shared_ptr<array> tmp;
switch(it->first){
case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
case DOT_TYPE: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case AXPY_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_N_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
case GEMV_T_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
case GER_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
case GEMM_NN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_NT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
case GEMM_TN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
case GEMM_TT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
default: throw std::invalid_argument("Unrecognized operation");
}

View File

@@ -166,9 +166,9 @@ array_expression array_expression::operator!()
{ return array_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
//
tools::shared_ptr<array_expression> expressions_tuple::create(array_expression const & s)
std::shared_ptr<array_expression> expressions_tuple::create(array_expression const & s)
{
return tools::shared_ptr<array_expression>(new array_expression(static_cast<array_expression const &>(s)));
return std::shared_ptr<array_expression>(new array_expression(static_cast<array_expression const &>(s)));
}
expressions_tuple::expressions_tuple(data_type const & data, order_type order) : data_(data), order_(order)

View File

@@ -1,3 +1,4 @@
#include <cmath>
#include "common.hpp"
#include "isaac/array.h"