[python] basic tensorflow wrapper working

This commit is contained in:
Philippe Tillet
2019-08-26 16:53:49 -07:00
parent 0e0399f866
commit 4075949f80
26 changed files with 702 additions and 968 deletions

View File

@@ -16,7 +16,6 @@
#include <unordered_map>
#include <iostream>
#include <list>
#include <deque>
#include <valarray>
#if defined(_MSC_VER)
@@ -84,8 +83,7 @@ template <typename Type, typename Key> struct set_caster {
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
@@ -95,7 +93,7 @@ template <typename Type, typename Key> struct set_caster {
return s.release();
}
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]"));
};
template <typename Type, typename Key, typename Value> struct map_caster {
@@ -121,12 +119,8 @@ template <typename Type, typename Key, typename Value> struct map_caster {
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
return_value_policy policy_key = return_value_policy_override<Key>::policy(policy);
return_value_policy policy_value = return_value_policy_override<Value>::policy(policy);
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
@@ -137,14 +131,14 @@ template <typename Type, typename Key, typename Value> struct map_caster {
return d.release();
}
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]"));
};
template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
if (!isinstance<sequence>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
@@ -167,8 +161,7 @@ private:
public:
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
@@ -180,15 +173,12 @@ public:
return l.release();
}
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]"));
};
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> { };
@@ -209,9 +199,9 @@ private:
public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
if (!isinstance<list>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
auto l = reinterpret_borrow<list>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
@@ -237,7 +227,7 @@ public:
return l.release();
}
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
};
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
@@ -284,7 +274,7 @@ template<typename T> struct optional_caster {
return true;
}
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]"));
};
#if PYBIND11_HAS_OPTIONAL
@@ -364,7 +354,7 @@ struct variant_caster<V<Ts...>> {
}
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
};
#if PYBIND11_HAS_VARIANT