[python] upgraded pybind11 ; forcing torch tensors to be contiguous()
This commit is contained in:
@@ -202,9 +202,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
||||
|
||||
|
||||
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
alignment_info.run(module);
|
||||
grids.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
@@ -29,11 +29,22 @@ def run_tf():
|
||||
|
||||
def run_torch():
|
||||
import torch as th
|
||||
th.manual_seed(0)
|
||||
M, N, K = 128, 128, 128
|
||||
a = th.randn(M, K).cuda()
|
||||
b = th.randn(K, N).cuda()
|
||||
th_c = th.matmul(a, b)
|
||||
tr_c = triton.ops.dot(a, b)
|
||||
print(tr_c)
|
||||
b.requires_grad_(True)
|
||||
#th_c = th.matmul(a, th.t(b))
|
||||
#th_d = th.matmul(th.t(th_c), b)
|
||||
tr_c = triton.ops.dot(a, b, False, True)
|
||||
#tr_d = triton.ops.dot(tr_c, b, True, False)
|
||||
y = th.sum(tr_c)
|
||||
#print('backprop', y)
|
||||
y.backward()
|
||||
#print('backward done')
|
||||
print(b.grad)
|
||||
#th_d.backward()
|
||||
#print(a.grad)
|
||||
|
||||
|
||||
run_torch()
|
@@ -35,7 +35,6 @@ void register_grid(size_t id,
|
||||
|
||||
void delete_grid(size_t id) {
|
||||
id_grid_map.erase(id);
|
||||
std::cout << "deleted " << id_grid_map.size() << std::endl;
|
||||
}
|
||||
|
||||
void register_fn(size_t id,
|
||||
@@ -46,7 +45,6 @@ void register_fn(size_t id,
|
||||
|
||||
void delete_fn(size_t id) {
|
||||
id_fn_map.erase(id);
|
||||
std::cout << "deleted " << id_fn_map.size() << std::endl;
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
@@ -415,10 +413,12 @@ void gen_torch_make_handles(std::ostream &os,
|
||||
ir::type* ty = arg->get_type();
|
||||
if(!ty->is_pointer_ty())
|
||||
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
|
||||
else
|
||||
else{
|
||||
os << " CHECK_INPUT(" << name << ");" << std::endl;
|
||||
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
@@ -435,6 +435,10 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
|
||||
}
|
||||
|
||||
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
if(outputs.size() == 1){
|
||||
os << " return " << outputs[0] << ";" << std::endl;
|
||||
return;
|
||||
}
|
||||
os << " return {";
|
||||
for(size_t i = 0; i < outputs.size(); i++){
|
||||
if(i > 0)
|
||||
@@ -467,6 +471,10 @@ std::tuple<std::string,
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "ATen/cuda/detail/CUDAHooks.h"
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
|
@@ -200,7 +200,8 @@ struct function_record {
|
||||
/// Special data structure which (temporarily) holds metadata about a bound class
|
||||
struct type_record {
|
||||
PYBIND11_NOINLINE type_record()
|
||||
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
|
||||
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
|
||||
default_holder(true), module_local(false) { }
|
||||
|
||||
/// Handle to the parent scope
|
||||
handle scope;
|
||||
@@ -214,11 +215,14 @@ struct type_record {
|
||||
/// How large is the underlying C++ type?
|
||||
size_t type_size = 0;
|
||||
|
||||
/// What is the alignment of the underlying C++ type?
|
||||
size_t type_align = 0;
|
||||
|
||||
/// How large is the type's holder?
|
||||
size_t holder_size = 0;
|
||||
|
||||
/// The global operator new can be overridden with a class-specific variant
|
||||
void *(*operator_new)(size_t) = ::operator new;
|
||||
void *(*operator_new)(size_t) = nullptr;
|
||||
|
||||
/// Function pointer to class_<..>::init_instance
|
||||
void (*init_instance)(instance *, const void *) = nullptr;
|
||||
@@ -278,7 +282,7 @@ struct type_record {
|
||||
}
|
||||
};
|
||||
|
||||
inline function_call::function_call(function_record &f, handle p) :
|
||||
inline function_call::function_call(const function_record &f, handle p) :
|
||||
func(f), parent(p) {
|
||||
args.reserve(f.nargs);
|
||||
args_convert.reserve(f.nargs);
|
||||
|
@@ -17,6 +17,7 @@
|
||||
#include <array>
|
||||
#include <limits>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(PYBIND11_CPP17)
|
||||
# if defined(__has_include)
|
||||
@@ -203,10 +204,10 @@ PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool t
|
||||
}
|
||||
|
||||
struct value_and_holder {
|
||||
instance *inst;
|
||||
size_t index;
|
||||
const detail::type_info *type;
|
||||
void **vh;
|
||||
instance *inst = nullptr;
|
||||
size_t index = 0u;
|
||||
const detail::type_info *type = nullptr;
|
||||
void **vh = nullptr;
|
||||
|
||||
// Main constructor for a found value/holder:
|
||||
value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) :
|
||||
@@ -215,7 +216,7 @@ struct value_and_holder {
|
||||
{}
|
||||
|
||||
// Default constructor (used to signal a value-and-holder not found by get_value_and_holder())
|
||||
value_and_holder() : inst{nullptr} {}
|
||||
value_and_holder() {}
|
||||
|
||||
// Used for past-the-end iterator
|
||||
value_and_holder(size_t index) : index{index} {}
|
||||
@@ -269,8 +270,8 @@ public:
|
||||
|
||||
struct iterator {
|
||||
private:
|
||||
instance *inst;
|
||||
const type_vec *types;
|
||||
instance *inst = nullptr;
|
||||
const type_vec *types = nullptr;
|
||||
value_and_holder curr;
|
||||
friend struct values_and_holders;
|
||||
iterator(instance *inst, const type_vec *tinfo)
|
||||
@@ -570,7 +571,17 @@ public:
|
||||
// Lazy allocation for unallocated values:
|
||||
if (vptr == nullptr) {
|
||||
auto *type = v_h.type ? v_h.type : typeinfo;
|
||||
if (type->operator_new) {
|
||||
vptr = type->operator_new(type->type_size);
|
||||
} else {
|
||||
#if defined(PYBIND11_CPP17)
|
||||
if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__)
|
||||
vptr = ::operator new(type->type_size,
|
||||
(std::align_val_t) type->type_align);
|
||||
else
|
||||
#endif
|
||||
vptr = ::operator new(type->type_size);
|
||||
}
|
||||
}
|
||||
value = vptr;
|
||||
}
|
||||
@@ -774,11 +785,47 @@ template <typename T1, typename T2> struct is_copy_constructible<std::pair<T1, T
|
||||
: all_of<is_copy_constructible<T1>, is_copy_constructible<T2>> {};
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
// polymorphic_type_hook<itype>::get(src, tinfo) determines whether the object pointed
|
||||
// to by `src` actually is an instance of some class derived from `itype`.
|
||||
// If so, it sets `tinfo` to point to the std::type_info representing that derived
|
||||
// type, and returns a pointer to the start of the most-derived object of that type
|
||||
// (in which `src` is a subobject; this will be the same address as `src` in most
|
||||
// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src`
|
||||
// and leaves `tinfo` at its default value of nullptr.
|
||||
//
|
||||
// The default polymorphic_type_hook just returns src. A specialization for polymorphic
|
||||
// types determines the runtime type of the passed object and adjusts the this-pointer
|
||||
// appropriately via dynamic_cast<void*>. This is what enables a C++ Animal* to appear
|
||||
// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is
|
||||
// registered with pybind11, and this Animal is in fact a Dog).
|
||||
//
|
||||
// You may specialize polymorphic_type_hook yourself for types that want to appear
|
||||
// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern
|
||||
// in performance-sensitive applications, used most notably in LLVM.)
|
||||
template <typename itype, typename SFINAE = void>
|
||||
struct polymorphic_type_hook
|
||||
{
|
||||
static const void *get(const itype *src, const std::type_info*&) { return src; }
|
||||
};
|
||||
template <typename itype>
|
||||
struct polymorphic_type_hook<itype, detail::enable_if_t<std::is_polymorphic<itype>::value>>
|
||||
{
|
||||
static const void *get(const itype *src, const std::type_info*& type) {
|
||||
type = src ? &typeid(*src) : nullptr;
|
||||
return dynamic_cast<const void*>(src);
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Generic type caster for objects stored on the heap
|
||||
template <typename type> class type_caster_base : public type_caster_generic {
|
||||
using itype = intrinsic_t<type>;
|
||||
|
||||
public:
|
||||
static PYBIND11_DESCR name() { return type_descr(_<type>()); }
|
||||
static constexpr auto name = _<type>();
|
||||
|
||||
type_caster_base() : type_caster_base(typeid(type)) { }
|
||||
explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { }
|
||||
@@ -793,32 +840,28 @@ public:
|
||||
return cast(&src, return_value_policy::move, parent);
|
||||
}
|
||||
|
||||
// Returns a (pointer, type_info) pair taking care of necessary RTTI type lookup for a
|
||||
// polymorphic type. If the instance isn't derived, returns the non-RTTI base version.
|
||||
template <typename T = itype, enable_if_t<std::is_polymorphic<T>::value, int> = 0>
|
||||
// Returns a (pointer, type_info) pair taking care of necessary type lookup for a
|
||||
// polymorphic type (using RTTI by default, but can be overridden by specializing
|
||||
// polymorphic_type_hook). If the instance isn't derived, returns the base version.
|
||||
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
|
||||
const void *vsrc = src;
|
||||
auto &cast_type = typeid(itype);
|
||||
const std::type_info *instance_type = nullptr;
|
||||
if (vsrc) {
|
||||
instance_type = &typeid(*src);
|
||||
if (!same_type(cast_type, *instance_type)) {
|
||||
// This is a base pointer to a derived type; if it is a pybind11-registered type, we
|
||||
// can get the correct derived pointer (which may be != base pointer) by a
|
||||
// dynamic_cast to most derived type:
|
||||
if (auto *tpi = get_type_info(*instance_type))
|
||||
return {dynamic_cast<const void *>(src), const_cast<const type_info *>(tpi)};
|
||||
}
|
||||
const void *vsrc = polymorphic_type_hook<itype>::get(src, instance_type);
|
||||
if (instance_type && !same_type(cast_type, *instance_type)) {
|
||||
// This is a base pointer to a derived type. If the derived type is registered
|
||||
// with pybind11, we want to make the full derived object available.
|
||||
// In the typical case where itype is polymorphic, we get the correct
|
||||
// derived pointer (which may be != base pointer) by a dynamic_cast to
|
||||
// most derived type. If itype is not polymorphic, we won't get here
|
||||
// except via a user-provided specialization of polymorphic_type_hook,
|
||||
// and the user has promised that no this-pointer adjustment is
|
||||
// required in that case, so it's OK to use static_cast.
|
||||
if (const auto *tpi = get_type_info(*instance_type))
|
||||
return {vsrc, tpi};
|
||||
}
|
||||
// Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so
|
||||
// don't do a cast
|
||||
return type_caster_generic::src_and_type(vsrc, cast_type, instance_type);
|
||||
}
|
||||
|
||||
// Non-polymorphic type, so no dynamic casting; just call the generic version directly
|
||||
template <typename T = itype, enable_if_t<!std::is_polymorphic<T>::value, int> = 0>
|
||||
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
|
||||
return type_caster_generic::src_and_type(src, typeid(itype));
|
||||
return type_caster_generic::src_and_type(src, cast_type, instance_type);
|
||||
}
|
||||
|
||||
static handle cast(const itype *src, return_value_policy policy, handle parent) {
|
||||
@@ -835,7 +878,7 @@ public:
|
||||
nullptr, nullptr, holder);
|
||||
}
|
||||
|
||||
template <typename T> using cast_op_type = cast_op_type<T>;
|
||||
template <typename T> using cast_op_type = detail::cast_op_type<T>;
|
||||
|
||||
operator itype*() { return (type *) value; }
|
||||
operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); }
|
||||
@@ -885,7 +928,7 @@ private:
|
||||
"std::reference_wrapper<T> caster requires T to have a caster with an `T &` operator");
|
||||
public:
|
||||
bool load(handle src, bool convert) { return subcaster.load(src, convert); }
|
||||
static PYBIND11_DESCR name() { return caster_t::name(); }
|
||||
static constexpr auto name = caster_t::name;
|
||||
static handle cast(const std::reference_wrapper<type> &src, return_value_policy policy, handle parent) {
|
||||
// It is definitely wrong to take ownership of this pointer, so mask that rvp
|
||||
if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic)
|
||||
@@ -900,7 +943,7 @@ public:
|
||||
protected: \
|
||||
type value; \
|
||||
public: \
|
||||
static PYBIND11_DESCR name() { return type_descr(py_name); } \
|
||||
static constexpr auto name = py_name; \
|
||||
template <typename T_, enable_if_t<std::is_same<type, remove_cv_t<T_>>::value, int> = 0> \
|
||||
static handle cast(T_ *src, return_value_policy policy, handle parent) { \
|
||||
if (!src) return none().release(); \
|
||||
@@ -977,20 +1020,34 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
static handle cast(T src, return_value_policy /* policy */, handle /* parent */) {
|
||||
if (std::is_floating_point<T>::value) {
|
||||
template<typename U = T>
|
||||
static typename std::enable_if<std::is_floating_point<U>::value, handle>::type
|
||||
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PyFloat_FromDouble((double) src);
|
||||
} else if (sizeof(T) <= sizeof(long)) {
|
||||
if (std::is_signed<T>::value)
|
||||
return PyLong_FromLong((long) src);
|
||||
else
|
||||
return PyLong_FromUnsignedLong((unsigned long) src);
|
||||
} else {
|
||||
if (std::is_signed<T>::value)
|
||||
return PyLong_FromLongLong((long long) src);
|
||||
else
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long) src);
|
||||
}
|
||||
|
||||
template<typename U = T>
|
||||
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value && (sizeof(U) <= sizeof(long)), handle>::type
|
||||
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PYBIND11_LONG_FROM_SIGNED((long) src);
|
||||
}
|
||||
|
||||
template<typename U = T>
|
||||
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type
|
||||
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src);
|
||||
}
|
||||
|
||||
template<typename U = T>
|
||||
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value && (sizeof(U) > sizeof(long)), handle>::type
|
||||
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PyLong_FromLongLong((long long) src);
|
||||
}
|
||||
|
||||
template<typename U = T>
|
||||
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value && (sizeof(U) > sizeof(unsigned long)), handle>::type
|
||||
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return PyLong_FromUnsignedLongLong((unsigned long long) src);
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(T, _<std::is_integral<T>::value>("int", "float"));
|
||||
@@ -1049,7 +1106,7 @@ public:
|
||||
|
||||
template <typename T> using cast_op_type = void*&;
|
||||
operator void *&() { return value; }
|
||||
static PYBIND11_DESCR name() { return type_descr(_("capsule")); }
|
||||
static constexpr auto name = _("capsule");
|
||||
private:
|
||||
void *value = nullptr;
|
||||
};
|
||||
@@ -1292,7 +1349,7 @@ public:
|
||||
return one_char;
|
||||
}
|
||||
|
||||
static PYBIND11_DESCR name() { return type_descr(_(PYBIND11_STRING_NAME)); }
|
||||
static constexpr auto name = _(PYBIND11_STRING_NAME);
|
||||
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||
};
|
||||
|
||||
@@ -1317,9 +1374,7 @@ public:
|
||||
return cast_impl(std::forward<T>(src), policy, parent, indices{});
|
||||
}
|
||||
|
||||
static PYBIND11_DESCR name() {
|
||||
return type_descr(_("Tuple[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
|
||||
}
|
||||
static constexpr auto name = _("Tuple[") + concat(make_caster<Ts>::name...) + _("]");
|
||||
|
||||
template <typename T> using cast_op_type = type;
|
||||
|
||||
@@ -1464,7 +1519,7 @@ struct move_only_holder_caster {
|
||||
auto *ptr = holder_helper<holder_type>::get(src);
|
||||
return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
|
||||
}
|
||||
static PYBIND11_DESCR name() { return type_caster_base<type>::name(); }
|
||||
static constexpr auto name = type_caster_base<type>::name;
|
||||
};
|
||||
|
||||
template <typename type, typename deleter>
|
||||
@@ -1495,10 +1550,10 @@ template <typename base, typename holder> struct is_holder_type :
|
||||
template <typename base, typename deleter> struct is_holder_type<base, std::unique_ptr<base, deleter>> :
|
||||
std::true_type {};
|
||||
|
||||
template <typename T> struct handle_type_name { static PYBIND11_DESCR name() { return _<T>(); } };
|
||||
template <> struct handle_type_name<bytes> { static PYBIND11_DESCR name() { return _(PYBIND11_BYTES_NAME); } };
|
||||
template <> struct handle_type_name<args> { static PYBIND11_DESCR name() { return _("*args"); } };
|
||||
template <> struct handle_type_name<kwargs> { static PYBIND11_DESCR name() { return _("**kwargs"); } };
|
||||
template <typename T> struct handle_type_name { static constexpr auto name = _<T>(); };
|
||||
template <> struct handle_type_name<bytes> { static constexpr auto name = _(PYBIND11_BYTES_NAME); };
|
||||
template <> struct handle_type_name<args> { static constexpr auto name = _("*args"); };
|
||||
template <> struct handle_type_name<kwargs> { static constexpr auto name = _("**kwargs"); };
|
||||
|
||||
template <typename type>
|
||||
struct pyobject_caster {
|
||||
@@ -1516,7 +1571,7 @@ struct pyobject_caster {
|
||||
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return src.inc_ref();
|
||||
}
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -1556,7 +1611,8 @@ template <typename T> using move_never = none_of<move_always<T>, move_if_unrefer
|
||||
// everything else returns a reference/pointer to a local variable.
|
||||
template <typename type> using cast_is_temporary_value_reference = bool_constant<
|
||||
(std::is_reference<type>::value || std::is_pointer<type>::value) &&
|
||||
!std::is_base_of<type_caster_generic, make_caster<type>>::value
|
||||
!std::is_base_of<type_caster_generic, make_caster<type>>::value &&
|
||||
!std::is_same<intrinsic_t<type>, void>::value
|
||||
>;
|
||||
|
||||
// When a value returned from a C++ function is being cast back to Python, we almost always want to
|
||||
@@ -1569,7 +1625,8 @@ template <typename Return, typename SFINAE = void> struct return_value_policy_ov
|
||||
template <typename Return> struct return_value_policy_override<Return,
|
||||
detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
|
||||
static return_value_policy policy(return_value_policy p) {
|
||||
return !std::is_lvalue_reference<Return>::value && !std::is_pointer<Return>::value
|
||||
return !std::is_lvalue_reference<Return>::value &&
|
||||
!std::is_pointer<Return>::value
|
||||
? return_value_policy::move : p;
|
||||
}
|
||||
};
|
||||
@@ -1798,7 +1855,7 @@ struct function_record;
|
||||
|
||||
/// Internal data associated with a single function call
|
||||
struct function_call {
|
||||
function_call(function_record &f, handle p); // Implementation in attr.h
|
||||
function_call(const function_record &f, handle p); // Implementation in attr.h
|
||||
|
||||
/// The function data:
|
||||
const function_record &func;
|
||||
@@ -1840,7 +1897,7 @@ public:
|
||||
static constexpr bool has_kwargs = kwargs_pos < 0;
|
||||
static constexpr bool has_args = args_pos < 0;
|
||||
|
||||
static PYBIND11_DESCR arg_names() { return detail::concat(make_caster<Args>::name()...); }
|
||||
static constexpr auto arg_names = concat(type_descr(make_caster<Args>::name)...);
|
||||
|
||||
bool load_args(function_call &call) {
|
||||
return load_impl_sequence(call, indices{});
|
||||
@@ -2059,9 +2116,13 @@ object object_api<Derived>::call(Args &&...args) const {
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
#define PYBIND11_MAKE_OPAQUE(Type) \
|
||||
#define PYBIND11_MAKE_OPAQUE(...) \
|
||||
namespace pybind11 { namespace detail { \
|
||||
template<> class type_caster<Type> : public type_caster_base<Type> { }; \
|
||||
template<> class type_caster<__VA_ARGS__> : public type_caster_base<__VA_ARGS__> { }; \
|
||||
}}
|
||||
|
||||
/// Lets you pass a type containing a `,` through a macro parameter without needing a separate
|
||||
/// typedef, e.g.: `PYBIND11_OVERLOAD(PYBIND11_TYPE(ReturnType<A, B>), PYBIND11_TYPE(Parent<C, D>), f, arg)`
|
||||
#define PYBIND11_TYPE(...) __VA_ARGS__
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -25,9 +25,13 @@ template <typename T> struct format_descriptor<std::complex<T>, detail::enable_i
|
||||
static std::string format() { return std::string(value); }
|
||||
};
|
||||
|
||||
#ifndef PYBIND11_CPP17
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
|
||||
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
|
@@ -10,6 +10,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "../attr.h"
|
||||
#include "../options.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
@@ -289,13 +290,9 @@ extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject
|
||||
inline void add_patient(PyObject *nurse, PyObject *patient) {
|
||||
auto &internals = get_internals();
|
||||
auto instance = reinterpret_cast<detail::instance *>(nurse);
|
||||
auto ¤t_patients = internals.patients[nurse];
|
||||
instance->has_patients = true;
|
||||
for (auto &p : current_patients)
|
||||
if (p == patient)
|
||||
return;
|
||||
Py_INCREF(patient);
|
||||
current_patients.push_back(patient);
|
||||
internals.patients[nurse].push_back(patient);
|
||||
}
|
||||
|
||||
inline void clear_patients(PyObject *self) {
|
||||
@@ -472,7 +469,7 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
|
||||
if (tinfo && tinfo->get_buffer)
|
||||
break;
|
||||
}
|
||||
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
|
||||
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
|
||||
if (view)
|
||||
view->obj = nullptr;
|
||||
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
|
||||
|
@@ -93,8 +93,8 @@
|
||||
#endif
|
||||
|
||||
#define PYBIND11_VERSION_MAJOR 2
|
||||
#define PYBIND11_VERSION_MINOR 2
|
||||
#define PYBIND11_VERSION_PATCH 4
|
||||
#define PYBIND11_VERSION_MINOR 3
|
||||
#define PYBIND11_VERSION_PATCH 0
|
||||
|
||||
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
|
||||
#if defined(_MSC_VER)
|
||||
@@ -159,6 +159,8 @@
|
||||
#define PYBIND11_BYTES_SIZE PyBytes_Size
|
||||
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
|
||||
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
|
||||
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
|
||||
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
|
||||
#define PYBIND11_BYTES_NAME "bytes"
|
||||
#define PYBIND11_STRING_NAME "str"
|
||||
#define PYBIND11_SLICE_OBJECT PyObject
|
||||
@@ -181,6 +183,8 @@
|
||||
#define PYBIND11_BYTES_SIZE PyString_Size
|
||||
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
|
||||
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
|
||||
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
|
||||
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
|
||||
#define PYBIND11_BYTES_NAME "str"
|
||||
#define PYBIND11_STRING_NAME "unicode"
|
||||
#define PYBIND11_SLICE_OBJECT PySliceObject
|
||||
@@ -208,6 +212,31 @@ extern "C" {
|
||||
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
|
||||
#define PYBIND11_CONCAT(first, second) first##second
|
||||
|
||||
#define PYBIND11_CHECK_PYTHON_VERSION \
|
||||
{ \
|
||||
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
|
||||
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
|
||||
const char *runtime_ver = Py_GetVersion(); \
|
||||
size_t len = std::strlen(compiled_ver); \
|
||||
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|
||||
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
|
||||
PyErr_Format(PyExc_ImportError, \
|
||||
"Python version mismatch: module was compiled for Python %s, " \
|
||||
"but the interpreter version is incompatible: %s.", \
|
||||
compiled_ver, runtime_ver); \
|
||||
return nullptr; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
|
||||
/** \rst
|
||||
***Deprecated in favor of PYBIND11_MODULE***
|
||||
|
||||
@@ -227,27 +256,10 @@ extern "C" {
|
||||
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
|
||||
static PyObject *pybind11_init(); \
|
||||
PYBIND11_PLUGIN_IMPL(name) { \
|
||||
int major, minor; \
|
||||
if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
|
||||
PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
|
||||
return nullptr; \
|
||||
} else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
|
||||
PyErr_Format(PyExc_ImportError, \
|
||||
"Python version mismatch: module was compiled for " \
|
||||
"version %i.%i, while the interpreter is running " \
|
||||
"version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
|
||||
major, minor); \
|
||||
return nullptr; \
|
||||
} \
|
||||
PYBIND11_CHECK_PYTHON_VERSION \
|
||||
try { \
|
||||
return pybind11_init(); \
|
||||
} catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
} PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
} \
|
||||
PyObject *pybind11_init()
|
||||
|
||||
@@ -271,29 +283,12 @@ extern "C" {
|
||||
#define PYBIND11_MODULE(name, variable) \
|
||||
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
|
||||
PYBIND11_PLUGIN_IMPL(name) { \
|
||||
int major, minor; \
|
||||
if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
|
||||
PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
|
||||
return nullptr; \
|
||||
} else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
|
||||
PyErr_Format(PyExc_ImportError, \
|
||||
"Python version mismatch: module was compiled for " \
|
||||
"version %i.%i, while the interpreter is running " \
|
||||
"version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
|
||||
major, minor); \
|
||||
return nullptr; \
|
||||
} \
|
||||
PYBIND11_CHECK_PYTHON_VERSION \
|
||||
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
|
||||
try { \
|
||||
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
||||
return m.ptr(); \
|
||||
} catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
} PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
} \
|
||||
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
|
||||
|
||||
@@ -391,7 +386,7 @@ struct instance {
|
||||
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
|
||||
nonsimple_values_and_holders nonsimple;
|
||||
};
|
||||
/// Weak references (needed for keep alive):
|
||||
/// Weak references
|
||||
PyObject *weakrefs;
|
||||
/// If true, the pointer is owned which means we're free to manage it with a holder.
|
||||
bool owned : 1;
|
||||
@@ -408,10 +403,10 @@ struct instance {
|
||||
* (which is typically the size of two pointers), or when multiple inheritance is used on the
|
||||
* python side. Non-simple layout allocates the required amount of memory to have multiple
|
||||
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
|
||||
* pointer to allocated space of the required space to hold a a sequence of value pointers and
|
||||
* pointer to allocated space of the required space to hold a sequence of value pointers and
|
||||
* holders followed `status`, a set of bit flags (1 byte each), i.e.
|
||||
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
|
||||
* `sizeof(void *)`. `nonsimple.holder_constructed` is, for convenience, a pointer to the
|
||||
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
|
||||
* beginning of the [bb...] block (but not independently allocated).
|
||||
*
|
||||
* Status bits indicate whether the associated holder is constructed (&
|
||||
@@ -584,6 +579,11 @@ template <typename T, typename... Us> using deferred_t = typename deferred_type<
|
||||
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
|
||||
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
|
||||
|
||||
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
|
||||
/// can be converted to a Base pointer)
|
||||
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
|
||||
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
|
||||
|
||||
template <template<typename...> class Base>
|
||||
struct is_template_base_of_impl {
|
||||
template <typename... Us> static std::true_type check(Base<Us...> *);
|
||||
@@ -702,9 +702,13 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_ar
|
||||
static std::string format() { return std::string(1, c); }
|
||||
};
|
||||
|
||||
#if !defined(PYBIND11_CPP17)
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
|
||||
|
||||
#endif
|
||||
|
||||
/// RAII wrapper that temporarily clears any Python error state
|
||||
struct error_scope {
|
||||
PyObject *type, *value, *trace;
|
||||
|
@@ -1,6 +1,5 @@
|
||||
/*
|
||||
pybind11/detail/descr.h: Helper type for concatenating type signatures
|
||||
either at runtime (C++11) or compile time (C++14)
|
||||
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
@@ -15,171 +14,87 @@
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/* Concatenate type signatures at compile time using C++14 */
|
||||
#if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
|
||||
#define PYBIND11_CONSTEXPR_DESCR
|
||||
#if !defined(_MSC_VER)
|
||||
# define PYBIND11_DESCR_CONSTEXPR static constexpr
|
||||
#else
|
||||
# define PYBIND11_DESCR_CONSTEXPR const
|
||||
#endif
|
||||
|
||||
template <size_t Size1, size_t Size2> class descr {
|
||||
template <size_t Size1_, size_t Size2_> friend class descr;
|
||||
public:
|
||||
constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
|
||||
: descr(text, types,
|
||||
make_index_sequence<Size1>(),
|
||||
make_index_sequence<Size2>()) { }
|
||||
/* Concatenate type signatures at compile time */
|
||||
template <size_t N, typename... Ts>
|
||||
struct descr {
|
||||
char text[N + 1];
|
||||
|
||||
constexpr const char *text() const { return m_text; }
|
||||
constexpr const std::type_info * const * types() const { return m_types; }
|
||||
constexpr descr() : text{'\0'} { }
|
||||
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
|
||||
|
||||
template <size_t OtherSize1, size_t OtherSize2>
|
||||
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
|
||||
return concat(other,
|
||||
make_index_sequence<Size1>(),
|
||||
make_index_sequence<Size2>(),
|
||||
make_index_sequence<OtherSize1>(),
|
||||
make_index_sequence<OtherSize2>());
|
||||
template <size_t... Is>
|
||||
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
|
||||
|
||||
template <typename... Chars>
|
||||
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
|
||||
|
||||
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
|
||||
return {{&typeid(Ts)..., nullptr}};
|
||||
}
|
||||
|
||||
protected:
|
||||
template <size_t... Indices1, size_t... Indices2>
|
||||
constexpr descr(
|
||||
char const (&text) [Size1+1],
|
||||
const std::type_info * const (&types) [Size2+1],
|
||||
index_sequence<Indices1...>, index_sequence<Indices2...>)
|
||||
: m_text{text[Indices1]..., '\0'},
|
||||
m_types{types[Indices2]..., nullptr } {}
|
||||
|
||||
template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
|
||||
size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
|
||||
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
|
||||
concat(const descr<OtherSize1, OtherSize2> &other,
|
||||
index_sequence<Indices1...>, index_sequence<Indices2...>,
|
||||
index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
|
||||
return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
|
||||
{ m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
|
||||
{ m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
|
||||
);
|
||||
}
|
||||
|
||||
protected:
|
||||
char m_text[Size1 + 1];
|
||||
const std::type_info * m_types[Size2 + 1];
|
||||
};
|
||||
|
||||
template <size_t Size> constexpr descr<Size - 1, 0> _(char const(&text)[Size]) {
|
||||
return descr<Size - 1, 0>(text, { nullptr });
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
|
||||
index_sequence<Is1...>, index_sequence<Is2...>) {
|
||||
return {a.text[Is1]..., b.text[Is2]...};
|
||||
}
|
||||
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
|
||||
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
|
||||
constexpr descr<0> _(char const(&)[1]) { return {}; }
|
||||
|
||||
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
|
||||
template <size_t...Digits> struct int_to_str<0, Digits...> {
|
||||
static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
|
||||
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
|
||||
};
|
||||
|
||||
// Ternary description (like std::conditional)
|
||||
template <bool B, size_t Size1, size_t Size2>
|
||||
constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
|
||||
return _(text1);
|
||||
}
|
||||
template <bool B, size_t Size1, size_t Size2>
|
||||
constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
|
||||
return _(text2);
|
||||
}
|
||||
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
|
||||
constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
|
||||
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
|
||||
constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
|
||||
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
|
||||
|
||||
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
|
||||
return int_to_str<Size / 10, Size % 10>::digits;
|
||||
}
|
||||
|
||||
template <typename Type> constexpr descr<1, 1> _() {
|
||||
return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
|
||||
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
|
||||
|
||||
constexpr descr<0> concat() { return {}; }
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
|
||||
|
||||
template <size_t N, typename... Ts, typename... Args>
|
||||
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
|
||||
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
|
||||
return d + _(", ") + concat(args...);
|
||||
}
|
||||
|
||||
inline constexpr descr<0, 0> concat() { return _(""); }
|
||||
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr) { return descr; }
|
||||
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr, Args&&... args) { return descr + _(", ") + concat(args...); }
|
||||
template <size_t Size1, size_t Size2> auto constexpr type_descr(descr<Size1, Size2> descr) { return _("{") + descr + _("}"); }
|
||||
|
||||
#define PYBIND11_DESCR constexpr auto
|
||||
|
||||
#else /* Simpler C++11 implementation based on run-time memory allocation and copying */
|
||||
|
||||
class descr {
|
||||
public:
|
||||
PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
|
||||
size_t nChars = len(text), nTypes = len(types);
|
||||
m_text = new char[nChars];
|
||||
m_types = new const std::type_info *[nTypes];
|
||||
memcpy(m_text, text, nChars * sizeof(char));
|
||||
memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
|
||||
return _("{") + descr + _("}");
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE descr operator+(descr &&d2) && {
|
||||
descr r;
|
||||
|
||||
size_t nChars1 = len(m_text), nTypes1 = len(m_types);
|
||||
size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
|
||||
|
||||
r.m_text = new char[nChars1 + nChars2 - 1];
|
||||
r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
|
||||
memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
|
||||
memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
|
||||
memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
|
||||
memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
|
||||
|
||||
delete[] m_text; delete[] m_types;
|
||||
delete[] d2.m_text; delete[] d2.m_types;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
char *text() { return m_text; }
|
||||
const std::type_info * * types() { return m_types; }
|
||||
|
||||
protected:
|
||||
PYBIND11_NOINLINE descr() { }
|
||||
|
||||
template <typename T> static size_t len(const T *ptr) { // return length including null termination
|
||||
const T *it = ptr;
|
||||
while (*it++ != (T) 0)
|
||||
;
|
||||
return static_cast<size_t>(it - ptr);
|
||||
}
|
||||
|
||||
const std::type_info **m_types = nullptr;
|
||||
char *m_text = nullptr;
|
||||
};
|
||||
|
||||
/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
|
||||
|
||||
PYBIND11_NOINLINE inline descr _(const char *text) {
|
||||
const std::type_info *types[1] = { nullptr };
|
||||
return descr(text, types);
|
||||
}
|
||||
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
|
||||
|
||||
template <typename Type> PYBIND11_NOINLINE descr _() {
|
||||
const std::type_info *types[2] = { &typeid(Type), nullptr };
|
||||
return descr("%", types);
|
||||
}
|
||||
|
||||
template <size_t Size> PYBIND11_NOINLINE descr _() {
|
||||
const std::type_info *types[1] = { nullptr };
|
||||
return descr(std::to_string(Size).c_str(), types);
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline descr concat() { return _(""); }
|
||||
PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
|
||||
template <typename... Args> PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward<Args>(args)...); }
|
||||
PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
|
||||
|
||||
#define PYBIND11_DESCR ::pybind11::detail::descr
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -24,7 +24,7 @@ public:
|
||||
|
||||
template <typename> using cast_op_type = value_and_holder &;
|
||||
operator value_and_holder &() { return *value; }
|
||||
static PYBIND11_DESCR name() { return type_descr(_<value_and_holder>()); }
|
||||
static constexpr auto name = _<value_and_holder>();
|
||||
|
||||
private:
|
||||
value_and_holder *value = nullptr;
|
||||
|
@@ -23,7 +23,7 @@ inline PyObject *make_object_base_type(PyTypeObject *metaclass);
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
|
||||
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (tstate))
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
|
||||
#else
|
||||
// Usually an int but a long on Cygwin64 with Python 3.x
|
||||
@@ -116,7 +116,7 @@ struct internals {
|
||||
struct type_info {
|
||||
PyTypeObject *type;
|
||||
const std::type_info *cpptype;
|
||||
size_t type_size, holder_size_in_ptrs;
|
||||
size_t type_size, type_align, holder_size_in_ptrs;
|
||||
void *(*operator_new)(size_t);
|
||||
void (*init_instance)(instance *, const void *);
|
||||
void (*dealloc)(value_and_holder &v_h);
|
||||
@@ -138,7 +138,13 @@ struct type_info {
|
||||
};
|
||||
|
||||
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
|
||||
#define PYBIND11_INTERNALS_VERSION 2
|
||||
#define PYBIND11_INTERNALS_VERSION 3
|
||||
|
||||
#if defined(_DEBUG)
|
||||
# define PYBIND11_BUILD_TYPE "_debug"
|
||||
#else
|
||||
# define PYBIND11_BUILD_TYPE ""
|
||||
#endif
|
||||
|
||||
#if defined(WITH_THREAD)
|
||||
# define PYBIND11_INTERNALS_KIND ""
|
||||
@@ -147,10 +153,10 @@ struct type_info {
|
||||
#endif
|
||||
|
||||
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
/// Each module locally stores a pointer to the `internals` data. The data
|
||||
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
|
||||
|
@@ -16,6 +16,8 @@
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
#include "common.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
/// Erase all occurrences of a substring
|
||||
|
@@ -17,6 +17,11 @@
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wconversion"
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
# ifdef __clang__
|
||||
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
|
||||
// under Clang, so disable that warning here:
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated"
|
||||
# endif
|
||||
# if __GNUC__ >= 7
|
||||
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
|
||||
# endif
|
||||
@@ -181,13 +186,13 @@ template <typename Type_> struct EigenProps {
|
||||
}
|
||||
}
|
||||
|
||||
static PYBIND11_DESCR descriptor() {
|
||||
constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
|
||||
constexpr bool show_order = is_eigen_dense_map<Type>::value;
|
||||
constexpr bool show_c_contiguous = show_order && requires_row_major;
|
||||
constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
|
||||
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
|
||||
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
|
||||
static constexpr bool show_c_contiguous = show_order && requires_row_major;
|
||||
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
|
||||
|
||||
return type_descr(_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
|
||||
static constexpr auto descriptor =
|
||||
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
|
||||
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
|
||||
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
|
||||
_("]") +
|
||||
@@ -200,9 +205,7 @@ template <typename Type_> struct EigenProps {
|
||||
_<show_writeable>(", flags.writeable", "") +
|
||||
_<show_c_contiguous>(", flags.c_contiguous", "") +
|
||||
_<show_f_contiguous>(", flags.f_contiguous", "") +
|
||||
_("]")
|
||||
);
|
||||
}
|
||||
_("]");
|
||||
};
|
||||
|
||||
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
|
||||
@@ -339,7 +342,7 @@ public:
|
||||
return cast_impl(src, policy, parent);
|
||||
}
|
||||
|
||||
static PYBIND11_DESCR name() { return props::descriptor(); }
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
operator Type*() { return &value; }
|
||||
operator Type&() { return value; }
|
||||
@@ -379,7 +382,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
static PYBIND11_DESCR name() { return props::descriptor(); }
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
|
||||
// types but not bound arguments). We still provide them (with an explicitly delete) so that
|
||||
@@ -524,7 +527,7 @@ public:
|
||||
}
|
||||
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
|
||||
|
||||
static PYBIND11_DESCR name() { return props::descriptor(); }
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
|
||||
// types but not bound arguments). We still provide them (with an explicitly delete) so that
|
||||
@@ -591,7 +594,7 @@ struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
|
||||
+ npy_format_descriptor<Scalar>::name() + _("]"));
|
||||
+ npy_format_descriptor<Scalar>::name + _("]"));
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -90,8 +90,14 @@ NAMESPACE_END(detail)
|
||||
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
|
||||
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
|
||||
optional parameter can be used to skip the registration of signal handlers (see the
|
||||
Python documentation for details). Calling this function again after the interpreter
|
||||
`Python documentation`_ for details). Calling this function again after the interpreter
|
||||
has already been initialized is a fatal error.
|
||||
|
||||
If initializing the Python interpreter fails, then the program is terminated. (This
|
||||
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
|
||||
of throwing exceptions on errors.)
|
||||
|
||||
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
|
||||
\endrst */
|
||||
inline void initialize_interpreter(bool init_signal_handlers = true) {
|
||||
if (Py_IsInitialized())
|
||||
|
@@ -54,9 +54,20 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
value = [func](Args... args) -> Return {
|
||||
// ensure GIL is held during functor destruction
|
||||
struct func_handle {
|
||||
function f;
|
||||
func_handle(function&& f_) : f(std::move(f_)) {}
|
||||
func_handle(const func_handle&) = default;
|
||||
~func_handle() {
|
||||
gil_scoped_acquire acq;
|
||||
object retval(func(std::forward<Args>(args)...));
|
||||
function kill_f(std::move(f));
|
||||
}
|
||||
};
|
||||
|
||||
value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
|
||||
gil_scoped_acquire acq;
|
||||
object retval(hfunc.f(std::forward<Args>(args)...));
|
||||
/* Visual studio 2015 parser issue: need parentheses around this expression */
|
||||
return (retval.template cast<Return>());
|
||||
};
|
||||
@@ -75,10 +86,8 @@ public:
|
||||
return cpp_function(std::forward<Func>(f_), policy).release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Callable[[") +
|
||||
argument_loader<Args...>::arg_names() + _("], ") +
|
||||
make_caster<retval_type>::name() +
|
||||
_("]"));
|
||||
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
|
||||
+ make_caster<retval_type>::name + _("]"));
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -25,7 +25,8 @@ class pythonbuf : public std::streambuf {
|
||||
private:
|
||||
using traits_type = std::streambuf::traits_type;
|
||||
|
||||
char d_buffer[1024];
|
||||
const size_t buf_size;
|
||||
std::unique_ptr<char[]> d_buffer;
|
||||
object pywrite;
|
||||
object pyflush;
|
||||
|
||||
@@ -42,8 +43,11 @@ private:
|
||||
// This subtraction cannot be negative, so dropping the sign
|
||||
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
|
||||
|
||||
{
|
||||
gil_scoped_acquire tmp;
|
||||
pywrite(line);
|
||||
pyflush();
|
||||
}
|
||||
|
||||
setp(pbase(), epptr());
|
||||
}
|
||||
@@ -51,10 +55,13 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
pythonbuf(object pyostream)
|
||||
: pywrite(pyostream.attr("write")),
|
||||
|
||||
pythonbuf(object pyostream, size_t buffer_size = 1024)
|
||||
: buf_size(buffer_size),
|
||||
d_buffer(new char[buf_size]),
|
||||
pywrite(pyostream.attr("write")),
|
||||
pyflush(pyostream.attr("flush")) {
|
||||
setp(d_buffer, d_buffer + sizeof(d_buffer) - 1);
|
||||
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
|
||||
}
|
||||
|
||||
/// Sync before destroy
|
||||
@@ -194,7 +201,7 @@ inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::strin
|
||||
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
|
||||
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
|
||||
.def("__enter__", &detail::OstreamRedirect::enter)
|
||||
.def("__exit__", [](detail::OstreamRedirect &self, args) { self.exit(); });
|
||||
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -18,9 +18,9 @@
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <initializer_list>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <typeindex>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -250,7 +250,7 @@ template <typename T> struct array_info_scalar {
|
||||
typedef T type;
|
||||
static constexpr bool is_array = false;
|
||||
static constexpr bool is_empty = false;
|
||||
static PYBIND11_DESCR extents() { return _(""); }
|
||||
static constexpr auto extents = _("");
|
||||
static void append_extents(list& /* shape */) { }
|
||||
};
|
||||
// Computes underlying type and a comma-separated list of extents for array
|
||||
@@ -269,15 +269,9 @@ template <typename T, size_t N> struct array_info<std::array<T, N>> {
|
||||
array_info<T>::append_extents(shape);
|
||||
}
|
||||
|
||||
template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
|
||||
static PYBIND11_DESCR extents() {
|
||||
return _<N>();
|
||||
}
|
||||
|
||||
template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
|
||||
static PYBIND11_DESCR extents() {
|
||||
return concat(_<N>(), array_info<T>::extents());
|
||||
}
|
||||
static constexpr auto extents = _<array_info<T>::is_array>(
|
||||
concat(_<N>(), array_info<T>::extents), _<N>()
|
||||
);
|
||||
};
|
||||
// For numpy we have special handling for arrays of characters, so we don't include
|
||||
// the size in the array extents.
|
||||
@@ -446,7 +440,7 @@ public:
|
||||
/// This is essentially the same as calling numpy.dtype(args) in Python.
|
||||
static dtype from_args(object args) {
|
||||
PyObject *ptr = nullptr;
|
||||
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
|
||||
if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<dtype>(ptr);
|
||||
}
|
||||
@@ -861,14 +855,14 @@ public:
|
||||
|
||||
// Reference to element at a given index
|
||||
template<typename... Ix> const T& at(Ix... index) const {
|
||||
if (sizeof...(index) != ndim())
|
||||
if ((ssize_t) sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
// Mutable reference to element at a given index
|
||||
template<typename... Ix> T& mutable_at(Ix... index) {
|
||||
if (sizeof...(index) != ndim())
|
||||
if ((ssize_t) sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
@@ -948,8 +942,8 @@ template <typename T>
|
||||
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
|
||||
static std::string format() {
|
||||
using namespace detail;
|
||||
PYBIND11_DESCR extents = _("(") + array_info<T>::extents() + _(")");
|
||||
return extents.text() + format_descriptor<remove_all_extents_t<T>>::format();
|
||||
static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
|
||||
return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -968,7 +962,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
return src.inc_ref();
|
||||
}
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -978,7 +972,34 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
|
||||
template <typename T, typename = void>
|
||||
struct npy_format_descriptor_name;
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, bool>::value>(
|
||||
_("bool"), _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
|
||||
_("float") + _<sizeof(T)*8>(), _("longdouble")
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
|
||||
|| std::is_same<typename T::value_type, double>::value>(
|
||||
_("complex") + _<sizeof(typename T::value_type)*16>(), _("longcomplex")
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
|
||||
: npy_format_descriptor_name<T> {
|
||||
private:
|
||||
// NB: the order here must match the one in common.h
|
||||
constexpr static const int values[15] = {
|
||||
@@ -997,25 +1018,10 @@ public:
|
||||
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<T, bool>::value>(_("bool"),
|
||||
_<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
|
||||
_("float") + _<sizeof(T)*8>(), _("longdouble"));
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
|
||||
_("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
|
||||
}
|
||||
};
|
||||
|
||||
#define PYBIND11_DECL_CHAR_FMT \
|
||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||
static constexpr auto name = _("S") + _<N>(); \
|
||||
static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
|
||||
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
|
||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
||||
@@ -1027,7 +1033,7 @@ private:
|
||||
public:
|
||||
static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
|
||||
|
||||
static PYBIND11_DESCR name() { return _("(") + array_info<T>::extents() + _(")") + base_descr::name(); }
|
||||
static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;
|
||||
static pybind11::dtype dtype() {
|
||||
list shape;
|
||||
array_info<T>::append_extents(shape);
|
||||
@@ -1039,7 +1045,7 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
|
||||
private:
|
||||
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
|
||||
public:
|
||||
static PYBIND11_DESCR name() { return base_descr::name(); }
|
||||
static constexpr auto name = base_descr::name;
|
||||
static pybind11::dtype dtype() { return base_descr::dtype(); }
|
||||
};
|
||||
|
||||
@@ -1052,7 +1058,7 @@ struct field_descriptor {
|
||||
};
|
||||
|
||||
inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
const std::initializer_list<field_descriptor>& fields,
|
||||
any_container<field_descriptor> fields,
|
||||
const std::type_info& tinfo, ssize_t itemsize,
|
||||
bool (*direct_converter)(PyObject *, void *&)) {
|
||||
|
||||
@@ -1061,7 +1067,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
pybind11_fail("NumPy: dtype is already registered");
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto field : fields) {
|
||||
for (auto field : *fields) {
|
||||
if (!field.descr)
|
||||
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
|
||||
field.name + "` @ " + tinfo.name());
|
||||
@@ -1078,7 +1084,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
// - https://github.com/numpy/numpy/pull/7798
|
||||
// Because of this, we won't use numpy's logic to generate buffer format
|
||||
// strings and will just do it ourselves.
|
||||
std::vector<field_descriptor> ordered_fields(fields);
|
||||
std::vector<field_descriptor> ordered_fields(std::move(fields));
|
||||
std::sort(ordered_fields.begin(), ordered_fields.end(),
|
||||
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
|
||||
ssize_t offset = 0;
|
||||
@@ -1114,7 +1120,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
template <typename T, typename SFINAE> struct npy_format_descriptor {
|
||||
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
||||
|
||||
static PYBIND11_DESCR name() { return make_caster<T>::name(); }
|
||||
static constexpr auto name = make_caster<T>::name;
|
||||
|
||||
static pybind11::dtype dtype() {
|
||||
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
|
||||
@@ -1125,8 +1131,8 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
|
||||
return format_str;
|
||||
}
|
||||
|
||||
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
|
||||
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
|
||||
static void register_dtype(any_container<field_descriptor> fields) {
|
||||
register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
|
||||
sizeof(T), &direct_converter);
|
||||
}
|
||||
|
||||
@@ -1199,7 +1205,8 @@ private:
|
||||
|
||||
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
|
||||
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
|
||||
({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
||||
(::std::vector<::pybind11::detail::field_descriptor> \
|
||||
{PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
|
||||
@@ -1220,7 +1227,8 @@ private:
|
||||
|
||||
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
|
||||
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
|
||||
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
||||
(::std::vector<::pybind11::detail::field_descriptor> \
|
||||
{PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
||||
|
||||
#endif // __CLION_IDE__
|
||||
|
||||
@@ -1458,7 +1466,10 @@ public:
|
||||
private:
|
||||
remove_reference_t<Func> f;
|
||||
|
||||
template <size_t Index> using param_n_t = typename pack_element<Index, typename vectorize_arg<Args>::call_type...>::type;
|
||||
// Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
|
||||
// when arg_call_types is manually inlined.
|
||||
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
|
||||
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
|
||||
|
||||
// Runs a vectorized function given arguments tuple and three index sequences:
|
||||
// - Index is the full set of 0 ... (N-1) argument indices;
|
||||
@@ -1498,7 +1509,7 @@ private:
|
||||
if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
|
||||
else result = array_t<Return>(shape);
|
||||
|
||||
if (size == 0) return result;
|
||||
if (size == 0) return std::move(result);
|
||||
|
||||
/* Call the function */
|
||||
if (trivial == broadcast_trivial::non_trivial)
|
||||
@@ -1506,7 +1517,7 @@ private:
|
||||
else
|
||||
apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
|
||||
|
||||
return result;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
||||
@@ -1559,9 +1570,7 @@ vectorize_extractor(const Func &f, Return (*) (Args ...)) {
|
||||
}
|
||||
|
||||
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
|
||||
static PYBIND11_DESCR name() {
|
||||
return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
|
||||
}
|
||||
static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor<T>::name + _("]");
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@@ -10,7 +10,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#if defined(__INTEL_COMPILER)
|
||||
# pragma warning push
|
||||
# pragma warning disable 68 // integer conversion resulted in a change of sign
|
||||
# pragma warning disable 186 // pointless comparison of unsigned integer with zero
|
||||
# pragma warning disable 878 // incompatible exception specifications
|
||||
# pragma warning disable 1334 // the "template" keyword used for syntactic disambiguation may only be used within a template
|
||||
# pragma warning disable 1682 // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
|
||||
# pragma warning disable 1786 // function "strdup" was declared deprecated
|
||||
# pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard
|
||||
# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline"
|
||||
#elif defined(_MSC_VER)
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter
|
||||
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
@@ -19,15 +29,6 @@
|
||||
# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name
|
||||
# pragma warning(disable: 4702) // warning C4702: unreachable code
|
||||
# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 68) // integer conversion resulted in a change of sign
|
||||
# pragma warning(disable: 186) // pointless comparison of unsigned integer with zero
|
||||
# pragma warning(disable: 878) // incompatible exception specifications
|
||||
# pragma warning(disable: 1334) // the "template" keyword used for syntactic disambiguation may only be used within a template
|
||||
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
|
||||
# pragma warning(disable: 1875) // offsetof applied to non-POD (Plain Old Data) types is nonstandard
|
||||
# pragma warning(disable: 2196) // warning #2196: routine is both "inline" and "noinline"
|
||||
#elif defined(__GNUG__) && !defined(__clang__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wunused-but-set-parameter"
|
||||
@@ -40,6 +41,11 @@
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if defined(__GNUG__) && !defined(__clang__)
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
|
||||
#include "attr.h"
|
||||
#include "options.h"
|
||||
#include "detail/class.h"
|
||||
@@ -51,6 +57,7 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
class cpp_function : public function {
|
||||
public:
|
||||
cpp_function() { }
|
||||
cpp_function(std::nullptr_t) { }
|
||||
|
||||
/// Construct a cpp_function from a vanilla function pointer
|
||||
template <typename Return, typename... Args, typename... Extra>
|
||||
@@ -93,7 +100,6 @@ protected:
|
||||
template <typename Func, typename Return, typename... Args, typename... Extra>
|
||||
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
|
||||
using namespace detail;
|
||||
|
||||
struct capture { remove_reference_t<Func> f; };
|
||||
|
||||
/* Store the function including any extra state it might have (e.g. a lambda capture object) */
|
||||
@@ -164,10 +170,11 @@ protected:
|
||||
process_attributes<Extra...>::init(extra..., rec);
|
||||
|
||||
/* Generate a readable signature describing the function's arguments and return value types */
|
||||
PYBIND11_DESCR signature = _("(") + cast_in::arg_names() + _(") -> ") + cast_out::name();
|
||||
static constexpr auto signature = _("(") + cast_in::arg_names + _(") -> ") + cast_out::name;
|
||||
PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types();
|
||||
|
||||
/* Register the function with Python from generic (non-templated) code */
|
||||
initialize_generic(rec, signature.text(), signature.types(), sizeof...(Args));
|
||||
initialize_generic(rec, signature.text, types.data(), sizeof...(Args));
|
||||
|
||||
if (cast_in::has_args) rec->has_args = true;
|
||||
if (cast_in::has_kwargs) rec->has_kwargs = true;
|
||||
@@ -217,16 +224,16 @@ protected:
|
||||
|
||||
/* Generate a proper function signature */
|
||||
std::string signature;
|
||||
size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0;
|
||||
while (true) {
|
||||
char c = text[char_index++];
|
||||
if (c == '\0')
|
||||
break;
|
||||
size_t type_index = 0, arg_index = 0;
|
||||
for (auto *pc = text; *pc != '\0'; ++pc) {
|
||||
const auto c = *pc;
|
||||
|
||||
if (c == '{') {
|
||||
// Write arg name for everything except *args, **kwargs and return type.
|
||||
if (type_depth == 0 && text[char_index] != '*' && arg_index < args) {
|
||||
if (!rec->args.empty() && rec->args[arg_index].name) {
|
||||
// Write arg name for everything except *args and **kwargs.
|
||||
if (*(pc + 1) == '*')
|
||||
continue;
|
||||
|
||||
if (arg_index < rec->args.size() && rec->args[arg_index].name) {
|
||||
signature += rec->args[arg_index].name;
|
||||
} else if (arg_index == 0 && rec->is_method) {
|
||||
signature += "self";
|
||||
@@ -234,17 +241,13 @@ protected:
|
||||
signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0));
|
||||
}
|
||||
signature += ": ";
|
||||
}
|
||||
++type_depth;
|
||||
} else if (c == '}') {
|
||||
--type_depth;
|
||||
if (type_depth == 0) {
|
||||
// Write default value if available.
|
||||
if (arg_index < rec->args.size() && rec->args[arg_index].descr) {
|
||||
signature += " = ";
|
||||
signature += rec->args[arg_index].descr;
|
||||
}
|
||||
arg_index++;
|
||||
}
|
||||
} else if (c == '%') {
|
||||
const std::type_info *t = types[type_index++];
|
||||
if (!t)
|
||||
@@ -269,14 +272,9 @@ protected:
|
||||
signature += c;
|
||||
}
|
||||
}
|
||||
if (type_depth != 0 || types[type_index] != nullptr)
|
||||
if (arg_index != args || types[type_index] != nullptr)
|
||||
pybind11_fail("Internal error while parsing type signature (2)");
|
||||
|
||||
#if !defined(PYBIND11_CONSTEXPR_DESCR)
|
||||
delete[] types;
|
||||
delete[] text;
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (strcmp(rec->name, "__next__") == 0) {
|
||||
std::free(rec->name);
|
||||
@@ -428,7 +426,7 @@ protected:
|
||||
using namespace detail;
|
||||
|
||||
/* Iterator over the list of potentially admissible overloads */
|
||||
function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
|
||||
const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
|
||||
*it = overloads;
|
||||
|
||||
/* Need to know how many arguments + keyword arguments there are to pick the right overload */
|
||||
@@ -485,7 +483,7 @@ protected:
|
||||
result other than PYBIND11_TRY_NEXT_OVERLOAD.
|
||||
*/
|
||||
|
||||
function_record &func = *it;
|
||||
const function_record &func = *it;
|
||||
size_t pos_args = func.nargs; // Number of positional arguments that we need
|
||||
if (func.has_args) --pos_args; // (but don't count py::args
|
||||
if (func.has_kwargs) --pos_args; // or py::kwargs)
|
||||
@@ -517,7 +515,7 @@ protected:
|
||||
// 1. Copy any position arguments given.
|
||||
bool bad_arg = false;
|
||||
for (; args_copied < args_to_copy; ++args_copied) {
|
||||
argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr;
|
||||
const argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr;
|
||||
if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) {
|
||||
bad_arg = true;
|
||||
break;
|
||||
@@ -658,13 +656,22 @@ protected:
|
||||
result = PYBIND11_TRY_NEXT_OVERLOAD;
|
||||
}
|
||||
|
||||
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
|
||||
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) {
|
||||
// The error reporting logic below expects 'it' to be valid, as it would be
|
||||
// if we'd encountered this failure in the first-pass loop.
|
||||
if (!result)
|
||||
it = &call.func;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error_already_set &e) {
|
||||
e.restore();
|
||||
return nullptr;
|
||||
#if defined(__GNUG__) && !defined(__clang__)
|
||||
} catch ( abi::__forced_unwind& ) {
|
||||
throw;
|
||||
#endif
|
||||
} catch (...) {
|
||||
/* When an exception is caught, give each registered exception
|
||||
translator a chance to translate it to a Python exception
|
||||
@@ -711,7 +718,7 @@ protected:
|
||||
" arguments. The following argument types are supported:\n";
|
||||
|
||||
int ctr = 0;
|
||||
for (function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
|
||||
for (const function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
|
||||
msg += " "+ std::to_string(++ctr) + ". ";
|
||||
|
||||
bool wrote_sig = false;
|
||||
@@ -899,6 +906,7 @@ protected:
|
||||
tinfo->type = (PyTypeObject *) m_ptr;
|
||||
tinfo->cpptype = rec.type;
|
||||
tinfo->type_size = rec.type_size;
|
||||
tinfo->type_align = rec.type_align;
|
||||
tinfo->operator_new = rec.operator_new;
|
||||
tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size);
|
||||
tinfo->init_instance = rec.init_instance;
|
||||
@@ -961,18 +969,18 @@ protected:
|
||||
tinfo->get_buffer_data = get_buffer_data;
|
||||
}
|
||||
|
||||
// rec_func must be set for either fget or fset.
|
||||
void def_property_static_impl(const char *name,
|
||||
handle fget, handle fset,
|
||||
detail::function_record *rec_fget) {
|
||||
const auto is_static = !(rec_fget->is_method && rec_fget->scope);
|
||||
const auto has_doc = rec_fget->doc && pybind11::options::show_user_defined_docstrings();
|
||||
|
||||
detail::function_record *rec_func) {
|
||||
const auto is_static = rec_func && !(rec_func->is_method && rec_func->scope);
|
||||
const auto has_doc = rec_func && rec_func->doc && pybind11::options::show_user_defined_docstrings();
|
||||
auto property = handle((PyObject *) (is_static ? get_internals().static_property_type
|
||||
: &PyProperty_Type));
|
||||
attr(name) = property(fget.ptr() ? fget : none(),
|
||||
fset.ptr() ? fset : none(),
|
||||
/*deleter*/none(),
|
||||
pybind11::str(has_doc ? rec_fget->doc : ""));
|
||||
pybind11::str(has_doc ? rec_func->doc : ""));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -990,11 +998,21 @@ template <typename T> struct has_operator_delete_size<T, void_t<decltype(static_
|
||||
: std::true_type { };
|
||||
/// Call class-specific delete if it exists or global otherwise. Can also be an overload set.
|
||||
template <typename T, enable_if_t<has_operator_delete<T>::value, int> = 0>
|
||||
void call_operator_delete(T *p, size_t) { T::operator delete(p); }
|
||||
void call_operator_delete(T *p, size_t, size_t) { T::operator delete(p); }
|
||||
template <typename T, enable_if_t<!has_operator_delete<T>::value && has_operator_delete_size<T>::value, int> = 0>
|
||||
void call_operator_delete(T *p, size_t s) { T::operator delete(p, s); }
|
||||
void call_operator_delete(T *p, size_t s, size_t) { T::operator delete(p, s); }
|
||||
|
||||
inline void call_operator_delete(void *p, size_t) { ::operator delete(p); }
|
||||
inline void call_operator_delete(void *p, size_t s, size_t a) {
|
||||
(void)s; (void)a;
|
||||
#if defined(PYBIND11_CPP17)
|
||||
if (a > __STDCPP_DEFAULT_NEW_ALIGNMENT__)
|
||||
::operator delete(p, s, std::align_val_t(a));
|
||||
else
|
||||
::operator delete(p, s);
|
||||
#else
|
||||
::operator delete(p);
|
||||
#endif
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
@@ -1004,10 +1022,18 @@ template <typename /*Derived*/, typename F>
|
||||
auto method_adaptor(F &&f) -> decltype(std::forward<F>(f)) { return std::forward<F>(f); }
|
||||
|
||||
template <typename Derived, typename Return, typename Class, typename... Args>
|
||||
auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { return pmf; }
|
||||
auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) {
|
||||
static_assert(detail::is_accessible_base_of<Class, Derived>::value,
|
||||
"Cannot bind an inaccessible base class method; use a lambda definition instead");
|
||||
return pmf;
|
||||
}
|
||||
|
||||
template <typename Derived, typename Return, typename Class, typename... Args>
|
||||
auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { return pmf; }
|
||||
auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const {
|
||||
static_assert(detail::is_accessible_base_of<Class, Derived>::value,
|
||||
"Cannot bind an inaccessible base class method; use a lambda definition instead");
|
||||
return pmf;
|
||||
}
|
||||
|
||||
template <typename type_, typename... options>
|
||||
class class_ : public detail::generic_type {
|
||||
@@ -1049,10 +1075,11 @@ public:
|
||||
record.name = name;
|
||||
record.type = &typeid(type);
|
||||
record.type_size = sizeof(conditional_t<has_alias, type_alias, type>);
|
||||
record.type_align = alignof(conditional_t<has_alias, type_alias, type>&);
|
||||
record.holder_size = sizeof(holder_type);
|
||||
record.init_instance = init_instance;
|
||||
record.dealloc = dealloc;
|
||||
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
|
||||
record.default_holder = detail::is_instantiation<std::unique_ptr, holder_type>::value;
|
||||
|
||||
set_operator_new<type>(&record);
|
||||
|
||||
@@ -1094,7 +1121,7 @@ public:
|
||||
"def_static(...) called with a non-static member function pointer");
|
||||
cpp_function cf(std::forward<Func>(f), name(name_), scope(*this),
|
||||
sibling(getattr(*this, name_, none())), extra...);
|
||||
attr(cf.name()) = cf;
|
||||
attr(cf.name()) = staticmethod(cf);
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -1158,7 +1185,7 @@ public:
|
||||
|
||||
template <typename C, typename D, typename... Extra>
|
||||
class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) {
|
||||
static_assert(std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
|
||||
static_assert(std::is_same<C, type>::value || std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
|
||||
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)),
|
||||
fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this));
|
||||
def_property(name, fget, fset, return_value_policy::reference_internal, extra...);
|
||||
@@ -1167,7 +1194,7 @@ public:
|
||||
|
||||
template <typename C, typename D, typename... Extra>
|
||||
class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) {
|
||||
static_assert(std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
|
||||
static_assert(std::is_same<C, type>::value || std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
|
||||
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this));
|
||||
def_property_readonly(name, fget, return_value_policy::reference_internal, extra...);
|
||||
return *this;
|
||||
@@ -1198,7 +1225,7 @@ public:
|
||||
/// Uses cpp_function's return_value_policy by default
|
||||
template <typename... Extra>
|
||||
class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) {
|
||||
return def_property(name, fget, cpp_function(), extra...);
|
||||
return def_property(name, fget, nullptr, extra...);
|
||||
}
|
||||
|
||||
/// Uses return_value_policy::reference by default
|
||||
@@ -1210,7 +1237,7 @@ public:
|
||||
/// Uses cpp_function's return_value_policy by default
|
||||
template <typename... Extra>
|
||||
class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) {
|
||||
return def_property_static(name, fget, cpp_function(), extra...);
|
||||
return def_property_static(name, fget, nullptr, extra...);
|
||||
}
|
||||
|
||||
/// Uses return_value_policy::reference_internal by default
|
||||
@@ -1239,22 +1266,28 @@ public:
|
||||
/// Uses cpp_function's return_value_policy by default
|
||||
template <typename... Extra>
|
||||
class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) {
|
||||
static_assert( 0 == detail::constexpr_sum(std::is_base_of<arg, Extra>::value...),
|
||||
"Argument annotations are not allowed for properties");
|
||||
auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset);
|
||||
auto *rec_active = rec_fget;
|
||||
if (rec_fget) {
|
||||
char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */
|
||||
detail::process_attributes<Extra...>::init(extra..., rec_fget);
|
||||
if (rec_fget->doc && rec_fget->doc != doc_prev) {
|
||||
free(doc_prev);
|
||||
rec_fget->doc = strdup(rec_fget->doc);
|
||||
}
|
||||
}
|
||||
if (rec_fset) {
|
||||
doc_prev = rec_fset->doc;
|
||||
char *doc_prev = rec_fset->doc;
|
||||
detail::process_attributes<Extra...>::init(extra..., rec_fset);
|
||||
if (rec_fset->doc && rec_fset->doc != doc_prev) {
|
||||
free(doc_prev);
|
||||
rec_fset->doc = strdup(rec_fset->doc);
|
||||
}
|
||||
if (! rec_active) rec_active = rec_fset;
|
||||
}
|
||||
def_property_static_impl(name, fget, fset, rec_fget);
|
||||
def_property_static_impl(name, fget, fset, rec_active);
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -1320,7 +1353,10 @@ private:
|
||||
v_h.set_holder_constructed(false);
|
||||
}
|
||||
else {
|
||||
detail::call_operator_delete(v_h.value_ptr<type>(), v_h.type->type_size);
|
||||
detail::call_operator_delete(v_h.value_ptr<type>(),
|
||||
v_h.type->type_size,
|
||||
v_h.type->type_align
|
||||
);
|
||||
}
|
||||
v_h.value_ptr() = nullptr;
|
||||
}
|
||||
@@ -1356,93 +1392,190 @@ detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetSta
|
||||
return {std::forward<GetState>(g), std::forward<SetState>(s)};
|
||||
}
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
struct enum_base {
|
||||
enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { }
|
||||
|
||||
PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) {
|
||||
m_base.attr("__entries") = dict();
|
||||
auto property = handle((PyObject *) &PyProperty_Type);
|
||||
auto static_property = handle((PyObject *) get_internals().static_property_type);
|
||||
|
||||
m_base.attr("__repr__") = cpp_function(
|
||||
[](handle arg) -> str {
|
||||
handle type = arg.get_type();
|
||||
object type_name = type.attr("__name__");
|
||||
dict entries = type.attr("__entries");
|
||||
for (const auto &kv : entries) {
|
||||
object other = kv.second[int_(0)];
|
||||
if (other.equal(arg))
|
||||
return pybind11::str("{}.{}").format(type_name, kv.first);
|
||||
}
|
||||
return pybind11::str("{}.???").format(type_name);
|
||||
}, is_method(m_base)
|
||||
);
|
||||
|
||||
m_base.attr("name") = property(cpp_function(
|
||||
[](handle arg) -> str {
|
||||
dict entries = arg.get_type().attr("__entries");
|
||||
for (const auto &kv : entries) {
|
||||
if (handle(kv.second[int_(0)]).equal(arg))
|
||||
return pybind11::str(kv.first);
|
||||
}
|
||||
return "???";
|
||||
}, is_method(m_base)
|
||||
));
|
||||
|
||||
m_base.attr("__doc__") = static_property(cpp_function(
|
||||
[](handle arg) -> std::string {
|
||||
std::string docstring;
|
||||
dict entries = arg.attr("__entries");
|
||||
if (((PyTypeObject *) arg.ptr())->tp_doc)
|
||||
docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n";
|
||||
docstring += "Members:";
|
||||
for (const auto &kv : entries) {
|
||||
auto key = std::string(pybind11::str(kv.first));
|
||||
auto comment = kv.second[int_(1)];
|
||||
docstring += "\n\n " + key;
|
||||
if (!comment.is_none())
|
||||
docstring += " : " + (std::string) pybind11::str(comment);
|
||||
}
|
||||
return docstring;
|
||||
}
|
||||
), none(), none(), "");
|
||||
|
||||
m_base.attr("__members__") = static_property(cpp_function(
|
||||
[](handle arg) -> dict {
|
||||
dict entries = arg.attr("__entries"), m;
|
||||
for (const auto &kv : entries)
|
||||
m[kv.first] = kv.second[int_(0)];
|
||||
return m;
|
||||
}), none(), none(), ""
|
||||
);
|
||||
|
||||
#define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \
|
||||
m_base.attr(op) = cpp_function( \
|
||||
[](object a, object b) { \
|
||||
if (!a.get_type().is(b.get_type())) \
|
||||
strict_behavior; \
|
||||
return expr; \
|
||||
}, \
|
||||
is_method(m_base))
|
||||
|
||||
#define PYBIND11_ENUM_OP_CONV(op, expr) \
|
||||
m_base.attr(op) = cpp_function( \
|
||||
[](object a_, object b_) { \
|
||||
int_ a(a_), b(b_); \
|
||||
return expr; \
|
||||
}, \
|
||||
is_method(m_base))
|
||||
|
||||
if (is_convertible) {
|
||||
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
|
||||
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
|
||||
|
||||
if (is_arithmetic) {
|
||||
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
|
||||
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
|
||||
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
|
||||
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
|
||||
PYBIND11_ENUM_OP_CONV("__and__", a & b);
|
||||
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
|
||||
PYBIND11_ENUM_OP_CONV("__or__", a | b);
|
||||
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
|
||||
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
|
||||
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
|
||||
}
|
||||
} else {
|
||||
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
|
||||
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
|
||||
|
||||
if (is_arithmetic) {
|
||||
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
|
||||
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
|
||||
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
|
||||
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
|
||||
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
|
||||
#undef PYBIND11_THROW
|
||||
}
|
||||
}
|
||||
|
||||
#undef PYBIND11_ENUM_OP_CONV
|
||||
#undef PYBIND11_ENUM_OP_STRICT
|
||||
|
||||
object getstate = cpp_function(
|
||||
[](object arg) { return int_(arg); }, is_method(m_base));
|
||||
|
||||
m_base.attr("__getstate__") = getstate;
|
||||
m_base.attr("__hash__") = getstate;
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) {
|
||||
dict entries = m_base.attr("__entries");
|
||||
str name(name_);
|
||||
if (entries.contains(name)) {
|
||||
std::string type_name = (std::string) str(m_base.attr("__name__"));
|
||||
throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!");
|
||||
}
|
||||
|
||||
entries[name] = std::make_pair(value, doc);
|
||||
m_base.attr(name) = value;
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE void export_values() {
|
||||
dict entries = m_base.attr("__entries");
|
||||
for (const auto &kv : entries)
|
||||
m_parent.attr(kv.first) = kv.second[int_(0)];
|
||||
}
|
||||
|
||||
handle m_base;
|
||||
handle m_parent;
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Binds C++ enumerations and enumeration classes to Python
|
||||
template <typename Type> class enum_ : public class_<Type> {
|
||||
public:
|
||||
using class_<Type>::def;
|
||||
using class_<Type>::def_property_readonly_static;
|
||||
using Base = class_<Type>;
|
||||
using Base::def;
|
||||
using Base::attr;
|
||||
using Base::def_property_readonly;
|
||||
using Base::def_property_readonly_static;
|
||||
using Scalar = typename std::underlying_type<Type>::type;
|
||||
|
||||
template <typename... Extra>
|
||||
enum_(const handle &scope, const char *name, const Extra&... extra)
|
||||
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
|
||||
|
||||
: class_<Type>(scope, name, extra...), m_base(*this, scope) {
|
||||
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
|
||||
constexpr bool is_convertible = std::is_convertible<Type, Scalar>::value;
|
||||
m_base.init(is_arithmetic, is_convertible);
|
||||
|
||||
auto m_entries_ptr = m_entries.inc_ref().ptr();
|
||||
def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
|
||||
if (pybind11::cast<Type>(kv.second) == value)
|
||||
return pybind11::str("{}.{}").format(name, kv.first);
|
||||
}
|
||||
return pybind11::str("{}.???").format(name);
|
||||
});
|
||||
def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) {
|
||||
dict m;
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
|
||||
m[kv.first] = kv.second;
|
||||
return m;
|
||||
}, return_value_policy::copy);
|
||||
def(init([](Scalar i) { return static_cast<Type>(i); }));
|
||||
def("__int__", [](Type value) { return (Scalar) value; });
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
def("__long__", [](Type value) { return (Scalar) value; });
|
||||
#endif
|
||||
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
|
||||
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
|
||||
if (is_arithmetic) {
|
||||
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
|
||||
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
|
||||
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
|
||||
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
|
||||
}
|
||||
if (std::is_convertible<Type, Scalar>::value) {
|
||||
// Don't provide comparison with the underlying type if the enum isn't convertible,
|
||||
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
|
||||
// convert Type to Scalar below anyway because this needs to compile).
|
||||
def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
|
||||
def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
|
||||
if (is_arithmetic) {
|
||||
def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
|
||||
def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
|
||||
def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
|
||||
def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
|
||||
def("__invert__", [](const Type &value) { return ~((Scalar) value); });
|
||||
def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
|
||||
def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
|
||||
def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
|
||||
def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
|
||||
def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
|
||||
def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
|
||||
def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
|
||||
def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
|
||||
def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
|
||||
}
|
||||
}
|
||||
def("__hash__", [](const Type &value) { return (Scalar) value; });
|
||||
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
|
||||
def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); },
|
||||
[](tuple t) { return static_cast<Type>(t[0].cast<Scalar>()); }));
|
||||
cpp_function setstate(
|
||||
[](Type &value, Scalar arg) { value = static_cast<Type>(arg); },
|
||||
is_method(*this));
|
||||
attr("__setstate__") = setstate;
|
||||
}
|
||||
|
||||
/// Export enumeration entries into the parent scope
|
||||
enum_& export_values() {
|
||||
for (const auto &kv : m_entries)
|
||||
m_parent.attr(kv.first) = kv.second;
|
||||
m_base.export_values();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Add an enumeration entry
|
||||
enum_& value(char const* name, Type value) {
|
||||
auto v = pybind11::cast(value, return_value_policy::copy);
|
||||
this->attr(name) = v;
|
||||
m_entries[pybind11::str(name)] = v;
|
||||
enum_& value(char const* name, Type value, const char *doc = nullptr) {
|
||||
m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
dict m_entries;
|
||||
handle m_parent;
|
||||
detail::enum_base m_base;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
@@ -1749,6 +1882,15 @@ public:
|
||||
auto const &internals = detail::get_internals();
|
||||
tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
|
||||
|
||||
if (!tstate) {
|
||||
/* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if
|
||||
calling from a Python thread). Since we use a different key, this ensures
|
||||
we don't create a new thread state and deadlock in PyEval_AcquireThread
|
||||
below. Note we don't save this state with internals.tstate, since we don't
|
||||
create it we would fail to clear it (its reference count should be > 0). */
|
||||
tstate = PyGILState_GetThisThreadState();
|
||||
}
|
||||
|
||||
if (!tstate) {
|
||||
tstate = PyThreadState_New(internals.istate);
|
||||
#if !defined(NDEBUG)
|
||||
@@ -1856,12 +1998,12 @@ class gil_scoped_release { };
|
||||
#endif
|
||||
|
||||
error_already_set::~error_already_set() {
|
||||
if (type) {
|
||||
if (m_type) {
|
||||
error_scope scope;
|
||||
gil_scoped_acquire gil;
|
||||
type.release().dec_ref();
|
||||
value.release().dec_ref();
|
||||
trace.release().dec_ref();
|
||||
m_type.release().dec_ref();
|
||||
m_value.release().dec_ref();
|
||||
m_trace.release().dec_ref();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1922,6 +2064,14 @@ inline function get_type_overload(const void *this_ptr, const detail::type_info
|
||||
return overload;
|
||||
}
|
||||
|
||||
/** \rst
|
||||
Try to retrieve a python method by the provided name from the instance pointed to by the this_ptr.
|
||||
|
||||
:this_ptr: The pointer to the object the overload should be retrieved for. This should be the first
|
||||
non-trampoline class encountered in the inheritance chain.
|
||||
:name: The name of the overloaded Python method to retrieve.
|
||||
:return: The Python method by this name from the object or an empty function wrapper.
|
||||
\endrst */
|
||||
template <class T> function get_overload(const T *this_ptr, const char *name) {
|
||||
auto tinfo = detail::get_type_info(typeid(T));
|
||||
return tinfo ? get_type_overload(this_ptr, tinfo, name) : function();
|
||||
@@ -1940,26 +2090,73 @@ template <class T> function get_overload(const T *this_ptr, const char *name) {
|
||||
} \
|
||||
}
|
||||
|
||||
/** \rst
|
||||
Macro to populate the virtual method in the trampoline class. This macro tries to look up a method named 'fn'
|
||||
from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return
|
||||
the appropriate type. See :ref:`overriding_virtuals` for more information. This macro should be used when the method
|
||||
name in C is not the same as the method name in Python. For example with `__str__`.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
std::string toString() override {
|
||||
PYBIND11_OVERLOAD_NAME(
|
||||
std::string, // Return type (ret_type)
|
||||
Animal, // Parent class (cname)
|
||||
toString, // Name of function in C++ (name)
|
||||
"__str__", // Name of method in Python (fn)
|
||||
);
|
||||
}
|
||||
\endrst */
|
||||
#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
|
||||
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
|
||||
PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \
|
||||
return cname::fn(__VA_ARGS__)
|
||||
|
||||
/** \rst
|
||||
Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD_NAME`, except that it
|
||||
throws if no overload can be found.
|
||||
\endrst */
|
||||
#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \
|
||||
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
|
||||
pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\"");
|
||||
PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \
|
||||
pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY(cname) "::" name "\"");
|
||||
|
||||
/** \rst
|
||||
Macro to populate the virtual method in the trampoline class. This macro tries to look up the method
|
||||
from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return
|
||||
the appropriate type. This macro should be used if the method name in C and in Python are identical.
|
||||
See :ref:`overriding_virtuals` for more information.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
class PyAnimal : public Animal {
|
||||
public:
|
||||
// Inherit the constructors
|
||||
using Animal::Animal;
|
||||
|
||||
// Trampoline (need one for each virtual function)
|
||||
std::string go(int n_times) override {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
std::string, // Return type (ret_type)
|
||||
Animal, // Parent class (cname)
|
||||
go, // Name of function in C++ (must match Python name) (fn)
|
||||
n_times // Argument(s) (...)
|
||||
);
|
||||
}
|
||||
};
|
||||
\endrst */
|
||||
#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \
|
||||
PYBIND11_OVERLOAD_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
|
||||
PYBIND11_OVERLOAD_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
|
||||
|
||||
/** \rst
|
||||
Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD`, except that it throws
|
||||
if no overload can be found.
|
||||
\endrst */
|
||||
#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \
|
||||
PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
|
||||
PYBIND11_OVERLOAD_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
||||
# pragma warning(pop)
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
/* Leave ignored warnings on */
|
||||
#elif defined(__GNUG__) && !defined(__clang__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
@@ -114,6 +114,35 @@ public:
|
||||
bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); }
|
||||
/// Equivalent to ``obj is None`` in Python.
|
||||
bool is_none() const { return derived().ptr() == Py_None; }
|
||||
/// Equivalent to obj == other in Python
|
||||
bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); }
|
||||
bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); }
|
||||
bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); }
|
||||
bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); }
|
||||
bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); }
|
||||
bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); }
|
||||
|
||||
object operator-() const;
|
||||
object operator~() const;
|
||||
object operator+(object_api const &other) const;
|
||||
object operator+=(object_api const &other) const;
|
||||
object operator-(object_api const &other) const;
|
||||
object operator-=(object_api const &other) const;
|
||||
object operator*(object_api const &other) const;
|
||||
object operator*=(object_api const &other) const;
|
||||
object operator/(object_api const &other) const;
|
||||
object operator/=(object_api const &other) const;
|
||||
object operator|(object_api const &other) const;
|
||||
object operator|=(object_api const &other) const;
|
||||
object operator&(object_api const &other) const;
|
||||
object operator&=(object_api const &other) const;
|
||||
object operator^(object_api const &other) const;
|
||||
object operator^=(object_api const &other) const;
|
||||
object operator<<(object_api const &other) const;
|
||||
object operator<<=(object_api const &other) const;
|
||||
object operator>>(object_api const &other) const;
|
||||
object operator>>=(object_api const &other) const;
|
||||
|
||||
PYBIND11_DEPRECATED("Use py::str(obj) instead")
|
||||
pybind11::str str() const;
|
||||
|
||||
@@ -124,6 +153,9 @@ public:
|
||||
int ref_count() const { return static_cast<int>(Py_REFCNT(derived().ptr())); }
|
||||
/// Return a handle to the Python type object underlying the instance
|
||||
handle get_type() const;
|
||||
|
||||
private:
|
||||
bool rich_compare(object_api const &other, int value) const;
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
@@ -292,15 +324,18 @@ public:
|
||||
/// Constructs a new exception from the current Python error indicator, if any. The current
|
||||
/// Python error indicator will be cleared.
|
||||
error_already_set() : std::runtime_error(detail::error_string()) {
|
||||
PyErr_Fetch(&type.ptr(), &value.ptr(), &trace.ptr());
|
||||
PyErr_Fetch(&m_type.ptr(), &m_value.ptr(), &m_trace.ptr());
|
||||
}
|
||||
|
||||
error_already_set(const error_already_set &) = default;
|
||||
error_already_set(error_already_set &&) = default;
|
||||
|
||||
inline ~error_already_set();
|
||||
|
||||
/// Give the currently-held error back to Python, if any. If there is currently a Python error
|
||||
/// already set it is cleared first. After this call, the current object no longer stores the
|
||||
/// error variables (but the `.what()` string is still available).
|
||||
void restore() { PyErr_Restore(type.release().ptr(), value.release().ptr(), trace.release().ptr()); }
|
||||
void restore() { PyErr_Restore(m_type.release().ptr(), m_value.release().ptr(), m_trace.release().ptr()); }
|
||||
|
||||
// Does nothing; provided for backwards compatibility.
|
||||
PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated")
|
||||
@@ -309,10 +344,14 @@ public:
|
||||
/// Check if the currently trapped error type matches the given Python exception class (or a
|
||||
/// subclass thereof). May also be passed a tuple to search for any exception class matches in
|
||||
/// the given tuple.
|
||||
bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), type.ptr()); }
|
||||
bool matches(handle exc) const { return PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()); }
|
||||
|
||||
const object& type() const { return m_type; }
|
||||
const object& value() const { return m_value; }
|
||||
const object& trace() const { return m_trace; }
|
||||
|
||||
private:
|
||||
object type, value, trace;
|
||||
object m_type, m_value, m_trace;
|
||||
};
|
||||
|
||||
/** \defgroup python_builtins _
|
||||
@@ -353,6 +392,14 @@ inline bool hasattr(handle obj, const char *name) {
|
||||
return PyObject_HasAttrString(obj.ptr(), name) == 1;
|
||||
}
|
||||
|
||||
inline void delattr(handle obj, handle name) {
|
||||
if (PyObject_DelAttr(obj.ptr(), name.ptr()) != 0) { throw error_already_set(); }
|
||||
}
|
||||
|
||||
inline void delattr(handle obj, const char *name) {
|
||||
if (PyObject_DelAttrString(obj.ptr(), name) != 0) { throw error_already_set(); }
|
||||
}
|
||||
|
||||
inline object getattr(handle obj, handle name) {
|
||||
PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr());
|
||||
if (!result) { throw error_already_set(); }
|
||||
@@ -424,7 +471,6 @@ object object_or_cast(T &&o);
|
||||
// Match a PyObject*, which we want to convert directly to handle via its converting constructor
|
||||
inline handle object_or_cast(PyObject *ptr) { return ptr; }
|
||||
|
||||
|
||||
template <typename Policy>
|
||||
class accessor : public object_api<accessor<Policy>> {
|
||||
using key_type = typename Policy::key_type;
|
||||
@@ -662,7 +708,7 @@ protected:
|
||||
|
||||
private:
|
||||
handle obj;
|
||||
PyObject *key, *value;
|
||||
PyObject *key = nullptr, *value = nullptr;
|
||||
ssize_t pos = -1;
|
||||
};
|
||||
NAMESPACE_END(iterator_policies)
|
||||
@@ -690,9 +736,14 @@ inline bool PyIterable_Check(PyObject *obj) {
|
||||
}
|
||||
|
||||
inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; }
|
||||
#endif
|
||||
|
||||
inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); }
|
||||
|
||||
inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; }
|
||||
|
||||
class kwargs_proxy : public handle {
|
||||
public:
|
||||
explicit kwargs_proxy(handle h) : handle(h) { }
|
||||
@@ -964,6 +1015,14 @@ public:
|
||||
none() : object(Py_None, borrowed_t{}) { }
|
||||
};
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
class ellipsis : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT(ellipsis, object, detail::PyEllipsis_Check)
|
||||
ellipsis() : object(Py_Ellipsis, borrowed_t{}) { }
|
||||
};
|
||||
#endif
|
||||
|
||||
class bool_ : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool)
|
||||
@@ -1074,6 +1133,13 @@ public:
|
||||
(ssize_t *) stop, (ssize_t *) step,
|
||||
(ssize_t *) slicelength) == 0;
|
||||
}
|
||||
bool compute(ssize_t length, ssize_t *start, ssize_t *stop, ssize_t *step,
|
||||
ssize_t *slicelength) const {
|
||||
return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr,
|
||||
length, start,
|
||||
stop, step,
|
||||
slicelength) == 0;
|
||||
}
|
||||
};
|
||||
|
||||
class capsule : public object {
|
||||
@@ -1137,6 +1203,7 @@ public:
|
||||
}
|
||||
size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
|
||||
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
|
||||
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
|
||||
detail::tuple_iterator begin() const { return {*this, 0}; }
|
||||
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
|
||||
};
|
||||
@@ -1174,6 +1241,7 @@ public:
|
||||
PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check)
|
||||
size_t size() const { return (size_t) PySequence_Size(m_ptr); }
|
||||
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
|
||||
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
|
||||
detail::sequence_iterator begin() const { return {*this, 0}; }
|
||||
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
|
||||
};
|
||||
@@ -1186,6 +1254,7 @@ public:
|
||||
}
|
||||
size_t size() const { return (size_t) PyList_Size(m_ptr); }
|
||||
detail::list_accessor operator[](size_t index) const { return {*this, index}; }
|
||||
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
|
||||
detail::list_iterator begin() const { return {*this, 0}; }
|
||||
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
|
||||
template <typename T> void append(T &&val) const {
|
||||
@@ -1221,6 +1290,11 @@ public:
|
||||
bool is_cpp_function() const { return (bool) cpp_function(); }
|
||||
};
|
||||
|
||||
class staticmethod : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_CVT(staticmethod, object, detail::PyStaticMethod_Check, PyStaticMethod_New)
|
||||
};
|
||||
|
||||
class buffer : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer)
|
||||
@@ -1279,6 +1353,21 @@ inline size_t len(handle h) {
|
||||
return (size_t) result;
|
||||
}
|
||||
|
||||
inline size_t len_hint(handle h) {
|
||||
#if PY_VERSION_HEX >= 0x03040000
|
||||
ssize_t result = PyObject_LengthHint(h.ptr(), 0);
|
||||
#else
|
||||
ssize_t result = PyObject_Length(h.ptr());
|
||||
#endif
|
||||
if (result < 0) {
|
||||
// Sometimes a length can't be determined at all (eg generators)
|
||||
// In which case simply return 0
|
||||
PyErr_Clear();
|
||||
return 0;
|
||||
}
|
||||
return (size_t) result;
|
||||
}
|
||||
|
||||
inline str repr(handle h) {
|
||||
PyObject *str_value = PyObject_Repr(h.ptr());
|
||||
if (!str_value) throw error_already_set();
|
||||
@@ -1328,5 +1417,55 @@ str_attr_accessor object_api<D>::doc() const { return attr("__doc__"); }
|
||||
template <typename D>
|
||||
handle object_api<D>::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); }
|
||||
|
||||
template <typename D>
|
||||
bool object_api<D>::rich_compare(object_api const &other, int value) const {
|
||||
int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value);
|
||||
if (rv == -1)
|
||||
throw error_already_set();
|
||||
return rv == 1;
|
||||
}
|
||||
|
||||
#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \
|
||||
template <typename D> object object_api<D>::op() const { \
|
||||
object result = reinterpret_steal<object>(fn(derived().ptr())); \
|
||||
if (!result.ptr()) \
|
||||
throw error_already_set(); \
|
||||
return result; \
|
||||
}
|
||||
|
||||
#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \
|
||||
template <typename D> \
|
||||
object object_api<D>::op(object_api const &other) const { \
|
||||
object result = reinterpret_steal<object>( \
|
||||
fn(derived().ptr(), other.derived().ptr())); \
|
||||
if (!result.ptr()) \
|
||||
throw error_already_set(); \
|
||||
return result; \
|
||||
}
|
||||
|
||||
PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert)
|
||||
PYBIND11_MATH_OPERATOR_UNARY (operator-, PyNumber_Negative)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
|
||||
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
|
||||
|
||||
#undef PYBIND11_MATH_OPERATOR_UNARY
|
||||
#undef PYBIND11_MATH_OPERATOR_BINARY
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
@@ -16,6 +16,7 @@
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <deque>
|
||||
#include <valarray>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -83,6 +84,7 @@ template <typename Type, typename Key> struct set_caster {
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Key>::policy(policy);
|
||||
pybind11::set s;
|
||||
for (auto &&value : src) {
|
||||
@@ -93,7 +95,7 @@ template <typename Type, typename Key> struct set_caster {
|
||||
return s.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]"));
|
||||
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Key, typename Value> struct map_caster {
|
||||
@@ -119,8 +121,12 @@ template <typename Type, typename Key, typename Value> struct map_caster {
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
dict d;
|
||||
return_value_policy policy_key = return_value_policy_override<Key>::policy(policy);
|
||||
return_value_policy policy_value = return_value_policy_override<Value>::policy(policy);
|
||||
return_value_policy policy_key = policy;
|
||||
return_value_policy policy_value = policy;
|
||||
if (!std::is_lvalue_reference<T>::value) {
|
||||
policy_key = return_value_policy_override<Key>::policy(policy_key);
|
||||
policy_value = return_value_policy_override<Value>::policy(policy_value);
|
||||
}
|
||||
for (auto &&kv : src) {
|
||||
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
|
||||
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
|
||||
@@ -131,14 +137,14 @@ template <typename Type, typename Key, typename Value> struct map_caster {
|
||||
return d.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]"));
|
||||
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Value> struct list_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src))
|
||||
if (!isinstance<sequence>(src) || isinstance<str>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<sequence>(src);
|
||||
value.clear();
|
||||
@@ -161,6 +167,7 @@ private:
|
||||
public:
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Value>::policy(policy);
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
@@ -173,12 +180,15 @@ public:
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]"));
|
||||
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
|
||||
: list_caster<std::vector<Type, Alloc>, Type> { };
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
|
||||
: list_caster<std::deque<Type, Alloc>, Type> { };
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
|
||||
: list_caster<std::list<Type, Alloc>, Type> { };
|
||||
|
||||
@@ -199,9 +209,9 @@ private:
|
||||
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<list>(src))
|
||||
if (!isinstance<sequence>(src))
|
||||
return false;
|
||||
auto l = reinterpret_borrow<list>(src);
|
||||
auto l = reinterpret_borrow<sequence>(src);
|
||||
if (!require_size(l.size()))
|
||||
return false;
|
||||
size_t ctr = 0;
|
||||
@@ -227,7 +237,7 @@ public:
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
|
||||
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
|
||||
@@ -274,7 +284,7 @@ template<typename T> struct optional_caster {
|
||||
return true;
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]"));
|
||||
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_OPTIONAL
|
||||
@@ -354,7 +364,7 @@ struct variant_caster<V<Ts...>> {
|
||||
}
|
||||
|
||||
using Type = V<Ts...>;
|
||||
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
|
||||
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_VARIANT
|
||||
|
@@ -122,7 +122,7 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
|
||||
|
||||
cl.def(init([](iterable it) {
|
||||
auto v = std::unique_ptr<Vector>(new Vector());
|
||||
v->reserve(len(it));
|
||||
v->reserve(len_hint(it));
|
||||
for (handle h : it)
|
||||
v->push_back(h.cast<T>());
|
||||
return v.release();
|
||||
@@ -136,6 +136,28 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
|
||||
"Extend the list by appending all the items in the given list"
|
||||
);
|
||||
|
||||
cl.def("extend",
|
||||
[](Vector &v, iterable it) {
|
||||
const size_t old_size = v.size();
|
||||
v.reserve(old_size + len_hint(it));
|
||||
try {
|
||||
for (handle h : it) {
|
||||
v.push_back(h.cast<T>());
|
||||
}
|
||||
} catch (const cast_error &) {
|
||||
v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
|
||||
try {
|
||||
v.shrink_to_fit();
|
||||
} catch (const std::exception &) {
|
||||
// Do nothing
|
||||
}
|
||||
throw;
|
||||
}
|
||||
},
|
||||
arg("L"),
|
||||
"Extend the list by appending all the items in the given list"
|
||||
);
|
||||
|
||||
cl.def("insert",
|
||||
[](Vector &v, SizeType i, const T &x) {
|
||||
if (i > v.size())
|
||||
@@ -579,6 +601,15 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
|
||||
return_value_policy::reference_internal // ref + keepalive
|
||||
);
|
||||
|
||||
cl.def("__contains__",
|
||||
[](Map &m, const KeyType &k) -> bool {
|
||||
auto it = m.find(k);
|
||||
if (it == m.end())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
);
|
||||
|
||||
// Assignment provided only if the type is copyable
|
||||
detail::map_assignment<Map, Class_>(cl);
|
||||
|
||||
|
@@ -217,6 +217,7 @@ class kernel:
|
||||
if fw.has_tensorflow():
|
||||
return self.fw_op(*op_args, id=op_id)
|
||||
elif fw.has_torch():
|
||||
return self.fw_op(op_id, *op_args)
|
||||
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
|
||||
return self.fw_op(op_id, *args)
|
||||
else:
|
||||
assert False
|
@@ -20,8 +20,8 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
|
||||
TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
|
||||
// prefetches operands
|
||||
TYPE a[SHAPE_A] = *pa;
|
||||
TYPE b[SHAPE_B] = *pb;
|
||||
TYPE a[SHAPE_A] = (*pa);
|
||||
TYPE b[SHAPE_B] = (*pb);
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += USE_A @ USE_B;
|
||||
@@ -80,16 +80,19 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
|
||||
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
|
||||
AT = transpose_a, BT = transpose_b, TYPE = dtype,
|
||||
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
|
||||
TM = [128], TN = [128], TK = [8], **macros)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
|
||||
ctx.save_for_backward(a, b, transpose_a, transpose_b)
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.t_a = transpose_a
|
||||
ctx.t_b = transpose_b
|
||||
return _dot._call(a, b, transpose_a, transpose_b)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
a, b, t_a, t_b = ctx.saved_tensors
|
||||
a, b = ctx.saved_tensors
|
||||
t_a, t_b = ctx.t_a, ctx.t_b
|
||||
if not t_a and not t_b:
|
||||
da = _dot._call(dy, b, False, True)
|
||||
db = _dot._call(a, dy, True, False)
|
||||
@@ -104,6 +107,6 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
db = _dot._call(dy, a, True, True)
|
||||
else:
|
||||
assert False
|
||||
return [da, db, None, None, None, None, None, None, None]
|
||||
return da, db, None, None, None, None, None, None, None
|
||||
|
||||
dot = _dot.apply
|
Reference in New Issue
Block a user