Runtime: More progress towards cuBLAS integration

This commit is contained in:
Philippe Tillet
2016-10-04 01:02:43 -04:00
parent fb9669a34d
commit ffb9548b6a
18 changed files with 170 additions and 210 deletions

View File

@@ -86,31 +86,5 @@ namespace tools
throw;
}
}
inline sc::expression_type extract_template_type(bp::object const & odtype)
{
std::string name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
if(name=="class")
name = bp::extract<std::string>(odtype.attr("__name__"))();
else
name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
if(name=="elementwise_1d") return sc::ELEMENTWISE_1D;
else if(name=="elementwise_2d") return sc::ELEMENTWISE_2D;
else if(name=="reduce_1d") return sc::REDUCE_1D;
else if(name=="reduce_2d_rows") return sc::REDUCE_2D_ROWS;
else if(name=="reduce_2d_cols") return sc::REDUCE_2D_COLS;
else if(name=="gemm_nn") return sc::GEMM_NN;
else if(name=="gemm_tn") return sc::GEMM_TN;
else if(name=="gemm_nt") return sc::GEMM_NT;
else if(name=="gemm_tt") return sc::GEMM_TT;
else
{
PyErr_SetString(PyExc_TypeError, "Template type not understood");
bp::throw_error_already_set();
throw;
}
}
}
#endif

View File

@@ -106,7 +106,7 @@ namespace detail
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
{
tpt::base* raw = bp::extract<tpt::base*>(tp);
return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
return std::make_shared<rt::profiles::value_type>(tools::extract_dtype(dtype), raw->getptr(), queue);
}
std::shared_ptr<sc::array>
@@ -219,9 +219,9 @@ namespace detail
{
static rt::profiles::value_type& get_item(rt::profiles::map_type& container, bp::tuple i_)
{
sc::expression_type expression = tools::extract_template_type(i_[0]);
tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
rt::profiles::map_type::iterator i = container.find(std::make_pair(expression, dtype));
rt::profiles::map_type::iterator i = container.find(std::make_pair(tpt->type(), dtype));
if (i == container.end())
{
PyErr_SetString(PyExc_KeyError, "Invalid key");
@@ -232,9 +232,9 @@ namespace detail
static void set_item(rt::profiles::map_type& container, bp::tuple i_, rt::profiles::value_type const & v)
{
sc::expression_type expression = tools::extract_template_type(i_[0]);
tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
container[std::make_pair(expression, dtype)].reset(new rt::profiles::value_type(v));
container[std::make_pair(tpt->type(), dtype)].reset(new rt::profiles::value_type(v));
}
};
}