Python: Fixed wrapper issues induced after cleaning
This commit is contained in:
@@ -23,6 +23,8 @@
|
||||
#include "common.hpp"
|
||||
#include "core.h"
|
||||
|
||||
namespace tpt = sc::templates;
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
@@ -103,7 +105,8 @@ namespace detail
|
||||
{
|
||||
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
|
||||
{
|
||||
return std::shared_ptr<rt::profiles::value_type>(new rt::profiles::value_type(tools::extract_template_type(tp), tools::extract_dtype(dtype), (sc::templates::base const &)bp::extract<sc::templates::base>(tp), queue));
|
||||
tpt::base* raw = bp::extract<tpt::base*>(tp);
|
||||
return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
|
||||
}
|
||||
|
||||
std::shared_ptr<sc::array>
|
||||
|
@@ -58,8 +58,8 @@ void export_templates()
|
||||
|
||||
//Base
|
||||
{
|
||||
#define __PROP(name) .def_readonly(#name, &tpt::base::parameters_type::name)
|
||||
bp::class_<tpt::base, boost::noncopyable>("base", bp::no_init)
|
||||
#define __PROP(name) .def_readonly(#name, &tpt::base::name)
|
||||
bp::class_<tpt::base, std::shared_ptr<tpt::base>, boost::noncopyable>("base", bp::no_init)
|
||||
.def("lmem_usage", &tpt::base::lmem_usage)
|
||||
.def("registers_usage", &tpt::base::registers_usage)
|
||||
.def("is_invalid", &tpt::base::is_invalid)
|
||||
@@ -68,26 +68,23 @@ void export_templates()
|
||||
#undef __PROP
|
||||
}
|
||||
|
||||
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init)\
|
||||
.add_property("ls0", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls0)\
|
||||
.add_property("ls1", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls1);
|
||||
bp::class_<tpt::base_impl, bp::bases<tpt::base>, boost::noncopyable>("base_impl", bp::no_init)
|
||||
.add_property("ls0", &tpt::base_impl::ls0)
|
||||
.add_property("ls1", &tpt::base_impl::ls1);
|
||||
|
||||
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
|
||||
#define WRAP_BASE(name) bp::class_<tpt::name, bp::bases<tpt::base_impl>, boost::noncopyable>(#name, bp::no_init);
|
||||
|
||||
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, std::shared_ptr<tpt::name>, bp::bases<basename>>(#name, bp::init<__VA_ARGS__>())\
|
||||
;
|
||||
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
|
||||
|
||||
//Vector AXPY
|
||||
WRAP_SINGLE_TEMPLATE(elementwise_1d, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_SINGLE_TEMPLATE(elementwise_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_SINGLE_TEMPLATE(reduce_1d, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(elementwise_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(elementwise_2d, tpt::base_impl, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_1d, tpt::base_impl, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_BASE(reduce_2d)
|
||||
WRAP_TEMPLATE(reduce_2d_rows, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_cols, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_rows, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_TEMPLATE(reduce_2d_cols, tpt::reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
|
||||
WRAP_BASE(gemm)
|
||||
WRAP_TEMPLATE(gemm_nn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tn, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_nt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tt, gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
|
||||
|
||||
WRAP_TEMPLATE(gemm_nn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tn, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_nt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
WRAP_TEMPLATE(gemm_tt, tpt::gemm, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
|
||||
}
|
||||
|
Reference in New Issue
Block a user