Runtime: More progress towards cuBLAS integration
This commit is contained in:
@@ -86,31 +86,5 @@ namespace tools
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
inline sc::expression_type extract_template_type(bp::object const & odtype)
|
||||
{
|
||||
std::string name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
|
||||
if(name=="class")
|
||||
name = bp::extract<std::string>(odtype.attr("__name__"))();
|
||||
else
|
||||
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
|
||||
|
||||
if(name=="elementwise_1d") return sc::ELEMENTWISE_1D;
|
||||
else if(name=="elementwise_2d") return sc::ELEMENTWISE_2D;
|
||||
else if(name=="reduce_1d") return sc::REDUCE_1D;
|
||||
else if(name=="reduce_2d_rows") return sc::REDUCE_2D_ROWS;
|
||||
else if(name=="reduce_2d_cols") return sc::REDUCE_2D_COLS;
|
||||
else if(name=="gemm_nn") return sc::GEMM_NN;
|
||||
else if(name=="gemm_tn") return sc::GEMM_TN;
|
||||
else if(name=="gemm_nt") return sc::GEMM_NT;
|
||||
else if(name=="gemm_tt") return sc::GEMM_TT;
|
||||
else
|
||||
{
|
||||
PyErr_SetString(PyExc_TypeError, "Template type not understood");
|
||||
bp::throw_error_already_set();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
@@ -106,7 +106,7 @@ namespace detail
|
||||
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
|
||||
{
|
||||
tpt::base* raw = bp::extract<tpt::base*>(tp);
|
||||
return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
|
||||
return std::make_shared<rt::profiles::value_type>(tools::extract_dtype(dtype), raw->getptr(), queue);
|
||||
}
|
||||
|
||||
std::shared_ptr<sc::array>
|
||||
@@ -219,9 +219,9 @@ namespace detail
|
||||
{
|
||||
static rt::profiles::value_type& get_item(rt::profiles::map_type& container, bp::tuple i_)
|
||||
{
|
||||
sc::expression_type expression = tools::extract_template_type(i_[0]);
|
||||
tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
|
||||
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
|
||||
rt::profiles::map_type::iterator i = container.find(std::make_pair(expression, dtype));
|
||||
rt::profiles::map_type::iterator i = container.find(std::make_pair(tpt->type(), dtype));
|
||||
if (i == container.end())
|
||||
{
|
||||
PyErr_SetString(PyExc_KeyError, "Invalid key");
|
||||
@@ -232,9 +232,9 @@ namespace detail
|
||||
|
||||
static void set_item(rt::profiles::map_type& container, bp::tuple i_, rt::profiles::value_type const & v)
|
||||
{
|
||||
sc::expression_type expression = tools::extract_template_type(i_[0]);
|
||||
tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
|
||||
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
|
||||
container[std::make_pair(expression, dtype)].reset(new rt::profiles::value_type(v));
|
||||
container[std::make_pair(tpt->type(), dtype)].reset(new rt::profiles::value_type(v));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
Reference in New Issue
Block a user