Code quality: removed tools::shared_ptr<>
This commit is contained in:
@@ -20,7 +20,7 @@ enum leaf_t
|
||||
class mapped_object;
|
||||
|
||||
typedef std::pair<int_t, leaf_t> mapping_key;
|
||||
typedef std::map<mapping_key, tools::shared_ptr<mapped_object> > mapping_type;
|
||||
typedef std::map<mapping_key, std::shared_ptr<mapped_object> > mapping_type;
|
||||
|
||||
/** @brief Mapped Object
|
||||
*
|
||||
|
@@ -76,15 +76,15 @@ protected:
|
||||
/** @brief Accessor for the numeric type */
|
||||
numeric_type get_numeric_type(isaac::array_expression const * array_expression, int_t root_idx) const;
|
||||
/** @brief Creates a binary leaf */
|
||||
template<class T> tools::shared_ptr<mapped_object> binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const;
|
||||
template<class T> std::shared_ptr<mapped_object> binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const;
|
||||
/** @brief Creates a value scalar mapping */
|
||||
tools::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
|
||||
std::shared_ptr<mapped_object> create(numeric_type dtype, values_holder) const;
|
||||
/** @brief Creates a vector mapping */
|
||||
tools::shared_ptr<mapped_object> create(array const *) const;
|
||||
std::shared_ptr<mapped_object> create(array const *) const;
|
||||
/** @brief Creates a tuple mapping */
|
||||
tools::shared_ptr<mapped_object> create(repeat_infos const &) const;
|
||||
std::shared_ptr<mapped_object> create(repeat_infos const &) const;
|
||||
/** @brief Creates a mapping */
|
||||
tools::shared_ptr<mapped_object> create(lhs_rhs_element const &) const;
|
||||
std::shared_ptr<mapped_object> create(lhs_rhs_element const &) const;
|
||||
public:
|
||||
map_functor(symbolic_binder & binder, mapping_type & mapping, const driver::Device &device);
|
||||
/** @brief Functor for traversing the tree */
|
||||
@@ -165,7 +165,7 @@ protected:
|
||||
static bool is_index_dot(op_element const & op);
|
||||
static std::string access_vector_type(std::string const & v, int i);
|
||||
|
||||
tools::shared_ptr<symbolic_binder> make_binder();
|
||||
std::shared_ptr<symbolic_binder> make_binder();
|
||||
static std::string vstore(unsigned int simd_width, std::string const & dtype, std::string const & value, std::string const & offset, std::string const & ptr, driver::backend_type backend);
|
||||
static std::string vload(unsigned int simd_width, std::string const & dtype, std::string const & offset, std::string const & ptr, driver::backend_type backend);
|
||||
static std::string append_width(std::string const & str, unsigned int width);
|
||||
@@ -182,7 +182,7 @@ public:
|
||||
std::string generate(const char * suffix, expressions_tuple const & expressions, driver::Device const & device);
|
||||
virtual int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const = 0;
|
||||
virtual void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & expressions) = 0;
|
||||
virtual tools::shared_ptr<base> clone() const = 0;
|
||||
virtual std::shared_ptr<base> clone() const = 0;
|
||||
private:
|
||||
binding_policy_t binding_policy_;
|
||||
};
|
||||
@@ -198,7 +198,7 @@ public:
|
||||
base_impl(parameters_type const & parameters, binding_policy_t binding_policy);
|
||||
int_t local_size_0() const;
|
||||
int_t local_size_1() const;
|
||||
tools::shared_ptr<base> clone() const;
|
||||
std::shared_ptr<base> clone() const;
|
||||
/** @brief returns whether or not the profile has undefined behavior on particular device */
|
||||
int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const;
|
||||
protected:
|
||||
|
@@ -14,7 +14,7 @@ namespace isaac
|
||||
|
||||
class model
|
||||
{
|
||||
typedef tools::shared_ptr<templates::base> template_pointer;
|
||||
typedef std::shared_ptr<templates::base> template_pointer;
|
||||
typedef std::vector< template_pointer > templates_container;
|
||||
|
||||
private:
|
||||
@@ -23,7 +23,7 @@ namespace isaac
|
||||
driver::Program& init(controller<expressions_tuple> const &);
|
||||
|
||||
public:
|
||||
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< tools::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
||||
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
||||
model(expression_type, numeric_type, templates::base const &, driver::CommandQueue const &);
|
||||
|
||||
void execute(controller<expressions_tuple> const &);
|
||||
@@ -35,18 +35,18 @@ namespace isaac
|
||||
private:
|
||||
templates_container templates_;
|
||||
template_pointer fallback_;
|
||||
tools::shared_ptr<predictors::random_forest> predictor_;
|
||||
std::shared_ptr<predictors::random_forest> predictor_;
|
||||
std::map<std::vector<int_t>, int> hardcoded_;
|
||||
std::map<driver::Context, std::map<std::string, std::shared_ptr<driver::Program> > > programs_;
|
||||
driver::CommandQueue queue_;
|
||||
};
|
||||
|
||||
typedef std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<model> > model_map_t;
|
||||
typedef std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<model> > model_map_t;
|
||||
|
||||
model_map_t init_models(driver::CommandQueue const & queue);
|
||||
model_map_t& models(driver::CommandQueue & queue);
|
||||
|
||||
extern std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks;
|
||||
extern std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks;
|
||||
extern std::map<driver::CommandQueue, model_map_t> models_;
|
||||
|
||||
}
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
#include "isaac/types.h"
|
||||
#include "isaac/value_scalar.h"
|
||||
#include "isaac/tools/shared_ptr.hpp"
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
namespace isaac
|
||||
@@ -290,9 +290,9 @@ controller<TYPE> control(TYPE const & x, execution_options_type const& execution
|
||||
class expressions_tuple
|
||||
{
|
||||
private:
|
||||
tools::shared_ptr<array_expression> create(array_expression const & s);
|
||||
std::shared_ptr<array_expression> create(array_expression const & s);
|
||||
public:
|
||||
typedef std::list<tools::shared_ptr<array_expression> > data_type;
|
||||
typedef std::list<std::shared_ptr<array_expression> > data_type;
|
||||
enum order_type { SEQUENTIAL, INDEPENDENT };
|
||||
|
||||
expressions_tuple(array_expression const & s0);
|
||||
|
@@ -1,162 +0,0 @@
|
||||
#ifndef ISAAC_TOOLS_SHARED_PTR_HPP
|
||||
#define ISAAC_TOOLS_SHARED_PTR_HPP
|
||||
|
||||
/* =========================================================================
|
||||
Copyright (c) 2010-2012, Institute for Microelectronics,
|
||||
Institute for Analysis and Scientific Computing,
|
||||
TU Wien.
|
||||
Portions of this software are copyright by UChicago Argonne, LLC.
|
||||
|
||||
-----------------
|
||||
ViennaCL - The Vienna Computing Library
|
||||
-----------------
|
||||
|
||||
Project Head: Karl Rupp rupp@iue.tuwien.ac.at
|
||||
|
||||
(A list of authors and contributors can be found in the PDF manual)
|
||||
|
||||
License: MIT (X11), see file LICENSE in the base directory
|
||||
============================================================================= */
|
||||
|
||||
/** @file tools/shared_ptr.hpp
|
||||
@brief Implementation of a shared pointer class (cf. tools::shared_ptr, boost::shared_ptr). Will be used until C++11 is widely available.
|
||||
|
||||
Contributed by Philippe Tillet.
|
||||
*/
|
||||
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
|
||||
namespace isaac
|
||||
{
|
||||
namespace tools
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
|
||||
/** @brief Reference counting class for the shared_ptr implementation */
|
||||
class count
|
||||
{
|
||||
public:
|
||||
count(unsigned int val) : val_(val){ }
|
||||
void dec(){ --val_; }
|
||||
void inc(){ ++val_; }
|
||||
bool is_null(){ return val_ == 0; }
|
||||
unsigned int val(){ return val_; }
|
||||
private:
|
||||
unsigned int val_;
|
||||
};
|
||||
|
||||
/** @brief Interface for the reference counter inside the shared_ptr */
|
||||
struct aux
|
||||
{
|
||||
detail::count count;
|
||||
|
||||
aux() :count(1) {}
|
||||
virtual void destroy()=0;
|
||||
virtual ~aux() {}
|
||||
};
|
||||
|
||||
/** @brief Implementation helper for the reference counting mechanism inside shared_ptr. */
|
||||
template<class U, class Deleter>
|
||||
struct auximpl: public detail::aux
|
||||
{
|
||||
U* p;
|
||||
Deleter d;
|
||||
|
||||
auximpl(U* pu, Deleter x) :p(pu), d(x) {}
|
||||
virtual void destroy() { d(p); }
|
||||
};
|
||||
|
||||
/** @brief Default deleter class for a pointer. The default is to just call 'delete' on the pointer. Provide your own implementations for 'delete[]' and 'free'. */
|
||||
template<class U>
|
||||
struct default_deleter
|
||||
{
|
||||
void operator()(U* p) const { delete p; }
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
class shared_ptr_base
|
||||
{
|
||||
protected:
|
||||
detail::aux* pa;
|
||||
public:
|
||||
unsigned int count() { return pa->count.val(); }
|
||||
};
|
||||
|
||||
/** @brief A shared pointer class similar to boost::shared_ptr. Reimplemented in order to avoid a Boost-dependency. Will be replaced by tools::shared_ptr as soon as C++11 is widely available. */
|
||||
template<class T>
|
||||
class shared_ptr : public shared_ptr_base
|
||||
{
|
||||
template<class U>
|
||||
friend class shared_ptr;
|
||||
|
||||
detail::aux* pa;
|
||||
T* pt;
|
||||
|
||||
public:
|
||||
|
||||
shared_ptr() :pa(NULL), pt(NULL) {}
|
||||
|
||||
template<class U, class Deleter>
|
||||
shared_ptr(U* pu, Deleter d) : pa(new detail::auximpl<U, Deleter>(pu, d)), pt(pu) {}
|
||||
|
||||
template<class U>
|
||||
explicit shared_ptr(U* pu) : pa(new detail::auximpl<U, detail::default_deleter<U> >(pu, detail::default_deleter<U>())), pt(pu) {}
|
||||
|
||||
template<class U>
|
||||
shared_ptr(const shared_ptr<U>& s) :pa(s.pa), pt(s.pt) { inc(); }
|
||||
|
||||
shared_ptr(const shared_ptr& s) :pa(s.pa), pt(s.pt) { inc(); }
|
||||
~shared_ptr() { dec(); }
|
||||
|
||||
T* get() const { return pt; }
|
||||
T* operator->() const { return pt; }
|
||||
T& operator*() const { return *pt; }
|
||||
|
||||
void reset() { shared_ptr<T>().swap(*this); }
|
||||
void reset(T * ptr) { shared_ptr<T>(ptr).swap(*this); }
|
||||
|
||||
void swap(shared_ptr<T> & other)
|
||||
{
|
||||
std::swap(pt,other.pt);
|
||||
std::swap(pa, other.pa);
|
||||
}
|
||||
|
||||
shared_ptr& operator=(const shared_ptr& s)
|
||||
{
|
||||
if (this!=&s)
|
||||
{
|
||||
dec();
|
||||
pa = s.pa;
|
||||
pt = s.pt;
|
||||
inc();
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void inc()
|
||||
{
|
||||
if (pa) pa->count.inc();
|
||||
}
|
||||
|
||||
void dec()
|
||||
{
|
||||
if (pa)
|
||||
{
|
||||
pa->count.dec();
|
||||
if (pa->count.is_null())
|
||||
{
|
||||
pa->destroy();
|
||||
delete pa;
|
||||
pa = NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -33,44 +33,44 @@ numeric_type base::map_functor::get_numeric_type(isaac::array_expression const *
|
||||
|
||||
/** @brief Binary leaf */
|
||||
template<class T>
|
||||
tools::shared_ptr<mapped_object> base::map_functor::binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const
|
||||
std::shared_ptr<mapped_object> base::map_functor::binary_leaf(isaac::array_expression const * array_expression, int_t root_idx, mapping_type const * mapping) const
|
||||
{
|
||||
return tools::shared_ptr<mapped_object>(new T(numeric_type_to_string(array_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, array_expression, root_idx)));
|
||||
return std::shared_ptr<mapped_object>(new T(numeric_type_to_string(array_expression->dtype()), binder_.get(), mapped_object::node_info(mapping, array_expression, root_idx)));
|
||||
}
|
||||
|
||||
/** @brief Scalar mapping */
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, values_holder) const
|
||||
std::shared_ptr<mapped_object> base::map_functor::create(numeric_type dtype, values_holder) const
|
||||
{
|
||||
std::string strdtype = numeric_type_to_string(dtype);
|
||||
return tools::shared_ptr<mapped_object>(new mapped_host_scalar(strdtype, binder_.get()));
|
||||
return std::shared_ptr<mapped_object>(new mapped_host_scalar(strdtype, binder_.get()));
|
||||
}
|
||||
|
||||
/** @brief Vector mapping */
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(array const * a) const
|
||||
std::shared_ptr<mapped_object> base::map_functor::create(array const * a) const
|
||||
{
|
||||
std::string dtype = numeric_type_to_string(a->dtype());
|
||||
unsigned int id = binder_.get(a->data());
|
||||
//Scalar
|
||||
if(a->shape()[0]==1 && a->shape()[1]==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
|
||||
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
|
||||
//Column vector
|
||||
else if(a->shape()[0]>1 && a->shape()[1]==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
|
||||
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
|
||||
//Row vector
|
||||
else if(a->shape()[0]==1 && a->shape()[1]>1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
|
||||
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
|
||||
//Matrix
|
||||
else
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
|
||||
return std::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
|
||||
}
|
||||
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
|
||||
std::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
|
||||
{
|
||||
//TODO: Make it less specific!
|
||||
return tools::shared_ptr<mapped_object>(new mapped_tuple(size_type(device_),binder_.get(),4));
|
||||
return std::shared_ptr<mapped_object>(new mapped_tuple(size_type(device_),binder_.get(),4));
|
||||
}
|
||||
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const & lhs_rhs) const
|
||||
std::shared_ptr<mapped_object> base::map_functor::create(lhs_rhs_element const & lhs_rhs) const
|
||||
{
|
||||
switch(lhs_rhs.type_family)
|
||||
{
|
||||
@@ -115,7 +115,7 @@ void base::map_functor::operator()(isaac::array_expression const & array_express
|
||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_outer>(&array_expression, root_idx, &mapping_)));
|
||||
else if (detail::is_cast(root_node.op))
|
||||
mapping_.insert(mapping_type::value_type(key, tools::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
|
||||
mapping_.insert(mapping_type::value_type(key, std::shared_ptr<mapped_object>(new mapped_cast(root_node.op.type, binder_.get()))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ std::string base::generate_arguments(std::string const & data_type, driver::Devi
|
||||
|
||||
void base::set_arguments(expressions_tuple const & expressions, driver::Kernel & kernel, unsigned int & current_arg)
|
||||
{
|
||||
tools::shared_ptr<symbolic_binder> binder = make_binder();
|
||||
std::shared_ptr<symbolic_binder> binder = make_binder();
|
||||
for (const auto & elem : expressions.data())
|
||||
traverse(*elem, (elem)->root(), set_arguments_functor(*binder, current_arg, kernel), true);
|
||||
}
|
||||
@@ -483,12 +483,12 @@ std::string base::append_width(std::string const & str, unsigned int width)
|
||||
return str + tools::to_string(width);
|
||||
}
|
||||
|
||||
tools::shared_ptr<symbolic_binder> base::make_binder()
|
||||
std::shared_ptr<symbolic_binder> base::make_binder()
|
||||
{
|
||||
if (binding_policy_==BIND_TO_HANDLE)
|
||||
return tools::shared_ptr<symbolic_binder>(new bind_to_handle());
|
||||
return std::shared_ptr<symbolic_binder>(new bind_to_handle());
|
||||
else
|
||||
return tools::shared_ptr<symbolic_binder>(new bind_all_unique());
|
||||
return std::shared_ptr<symbolic_binder>(new bind_all_unique());
|
||||
}
|
||||
|
||||
|
||||
@@ -514,7 +514,7 @@ std::string base::generate(const char * suffix, expressions_tuple const & expres
|
||||
|
||||
//Create mapping
|
||||
std::vector<mapping_type> mappings(expressions.data().size());
|
||||
tools::shared_ptr<symbolic_binder> binder = make_binder();
|
||||
std::shared_ptr<symbolic_binder> binder = make_binder();
|
||||
for (mit = mappings.begin(), sit = expressions.data().begin(); sit != expressions.data().end(); ++sit, ++mit)
|
||||
traverse(**sit, (*sit)->root(), map_functor(*binder,*mit,device), true);
|
||||
|
||||
@@ -538,8 +538,8 @@ int_t base_impl<TType, PType>::local_size_1() const
|
||||
{ return p_.local_size_1; }
|
||||
|
||||
template<class TType, class PType>
|
||||
tools::shared_ptr<base> base_impl<TType, PType>::clone() const
|
||||
{ return tools::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
|
||||
std::shared_ptr<base> base_impl<TType, PType>::clone() const
|
||||
{ return std::shared_ptr<base>(new TType(*dynamic_cast<TType const *>(this))); }
|
||||
|
||||
template<class TType, class PType>
|
||||
int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, driver::Device const & device) const
|
||||
|
@@ -85,7 +85,7 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
||||
return *program;
|
||||
}
|
||||
|
||||
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
||||
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
||||
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
|
||||
{}
|
||||
|
||||
@@ -166,27 +166,27 @@ namespace detail
|
||||
throw std::invalid_argument("Invalid datatype: " + name);
|
||||
}
|
||||
|
||||
static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
|
||||
static std::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
|
||||
{
|
||||
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
|
||||
if(template_name=="axpy")
|
||||
return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||
return std::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="dot")
|
||||
return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
|
||||
return std::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="ger")
|
||||
return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
return std::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemv_n")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemv_t")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemm_nn")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
else if(template_name.find("gemm_tn")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
else if(template_name.find("gemm_nt")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
else if(template_name.find("gemm_tt")!=std::string::npos)
|
||||
return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
return std::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
||||
else
|
||||
throw std::invalid_argument("Invalid expression: " + template_name);
|
||||
}
|
||||
@@ -222,7 +222,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
||||
numeric_type dtype = detail::get_dtype(elem);
|
||||
|
||||
// Get profiles
|
||||
std::vector<tools::shared_ptr<templates::base> > templates;
|
||||
std::vector<std::shared_ptr<templates::base> > templates;
|
||||
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
||||
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
||||
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
|
||||
@@ -231,10 +231,10 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
||||
{
|
||||
// Get predictor
|
||||
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
|
||||
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
|
||||
}
|
||||
else
|
||||
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -242,9 +242,9 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
||||
}
|
||||
|
||||
|
||||
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
|
||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
|
||||
{
|
||||
typedef tools::shared_ptr<templates::base> ptr_t;
|
||||
typedef std::shared_ptr<templates::base> ptr_t;
|
||||
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
|
||||
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||
for(auto DTYPE : types)
|
||||
@@ -271,7 +271,7 @@ model_map_t init_models(driver::CommandQueue & queue)
|
||||
|
||||
for(numeric_type dtype: dtypes)
|
||||
for(expression_type etype: etypes)
|
||||
res[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
|
||||
res[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
|
||||
|
||||
if(const char * homepath = std::getenv("HOME"))
|
||||
import(std::string(homepath) + "/.isaac/devices/device0.json", queue, res);
|
||||
@@ -286,7 +286,7 @@ model_map_t& models(driver::CommandQueue & queue)
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||
std::map<driver::CommandQueue, model_map_t> models_;
|
||||
|
||||
}
|
||||
|
@@ -177,29 +177,29 @@ namespace isaac
|
||||
|
||||
/*----Parse required temporaries-----*/
|
||||
detail::parse(tree, rootidx, breakpoints, final_type);
|
||||
std::vector<tools::shared_ptr<array> > temporaries_;
|
||||
std::vector<std::shared_ptr<array> > temporaries_;
|
||||
|
||||
/*----Compute required temporaries----*/
|
||||
for(detail::breakpoints_t::iterator it = breakpoints.begin() ; it != breakpoints.end() ; ++it)
|
||||
{
|
||||
tools::shared_ptr<model> const & pmodel = models[std::make_pair(it->first, dtype)];
|
||||
std::shared_ptr<model> const & pmodel = models[std::make_pair(it->first, dtype)];
|
||||
array_expression::node const & node = tree[it->second->node_index];
|
||||
array_expression::node const & lmost = lhs_most(tree, node);
|
||||
|
||||
//Creates temporary
|
||||
tools::shared_ptr<array> tmp;
|
||||
std::shared_ptr<array> tmp;
|
||||
switch(it->first){
|
||||
case DOT_TYPE: tmp = tools::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
case DOT_TYPE: tmp = std::shared_ptr<array>(new array(1, dtype, context)); break;
|
||||
|
||||
case AXPY_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_N_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_T_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case AXPY_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_N_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMV_T_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
|
||||
case GER_TYPE: tmp = tools::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMM_TN_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_TT_TYPE: tmp = tools::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case GER_TYPE: tmp = std::shared_ptr<array>(new array(lmost.lhs.array->shape()[0], lmost.lhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_NT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[0], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
case GEMM_TN_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[1], dtype, context)); break;
|
||||
case GEMM_TT_TYPE: tmp = std::shared_ptr<array>(new array(node.lhs.array->shape()[1], node.rhs.array->shape()[0], dtype, context)); break;
|
||||
|
||||
default: throw std::invalid_argument("Unrecognized operation");
|
||||
}
|
||||
|
@@ -166,9 +166,9 @@ array_expression array_expression::operator!()
|
||||
{ return array_expression(*this, invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_NEGATE_TYPE), context_, INT_TYPE, shape_); }
|
||||
|
||||
//
|
||||
tools::shared_ptr<array_expression> expressions_tuple::create(array_expression const & s)
|
||||
std::shared_ptr<array_expression> expressions_tuple::create(array_expression const & s)
|
||||
{
|
||||
return tools::shared_ptr<array_expression>(new array_expression(static_cast<array_expression const &>(s)));
|
||||
return std::shared_ptr<array_expression>(new array_expression(static_cast<array_expression const &>(s)));
|
||||
}
|
||||
|
||||
expressions_tuple::expressions_tuple(data_type const & data, order_type order) : data_(data), order_(order)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <cmath>
|
||||
#include "common.hpp"
|
||||
#include "isaac/array.h"
|
||||
|
||||
|
Reference in New Issue
Block a user