[python] basic tensorflow wrapper working

This commit is contained in:
Philippe Tillet
2019-08-26 16:53:49 -07:00
parent 0e0399f866
commit 4075949f80
26 changed files with 702 additions and 968 deletions

View File

@@ -41,7 +41,7 @@ if(BUILD_PYTHON_MODULE)
# extra tensorflow ops (e.g., alloc_empty)
file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc)
add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
target_link_libraries(extra_tf_ops ${TF_LIBS})
target_link_libraries(extra_tf_ops triton ${TF_LIBS})
endif()

View File

@@ -153,8 +153,8 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
opt.defines.push_back({"AT", {""}});
if(BT)
opt.defines.push_back({"BT", {""}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TM", {"32"}});
opt.defines.push_back({"TN", {"32"}});
opt.defines.push_back({"TK", {"32"}});
opt.num_warps = {1, 2, 4, 8};
rt::function function(src, opt);
@@ -208,7 +208,7 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, false, 128, 128, 128}
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},

View File

@@ -282,7 +282,10 @@ void grids::run(ir::module &mod) {
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
assert(num_warps_ == effective_num_warps);
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* Scan-line */
@@ -305,7 +308,9 @@ void grids::run(ir::module &mod) {
std::string str_d = std::to_string(d);
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
}
assert(num_threads == effective_num_threads);
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
}

View File

@@ -160,7 +160,12 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
// triton-ir code-gen
auto ir = make_ir(parser);
// binary code-gen
auto bin = make_bin(*ir, stream->context(), opt);
std::unique_ptr<driver::module> bin;
try{
bin = make_bin(*ir, stream->context(), opt);
}catch(const std::runtime_error& e) {
return;
}
// benchmark
ir::function *tmp = ir->get_function_list()[0];
caller call(tmp, std::move(bin), opt);
@@ -177,21 +182,21 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
std::unique_ptr<codegen::target> target = context->device()->make_target();
// create passes
codegen::analysis::grids tune(opt.num_warps);
codegen::analysis::grids grids(opt.num_warps);
codegen::analysis::shmem::info shmem_info;
codegen::analysis::shmem::liveness shmem_liveness(&shmem_info);
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tune);
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&tune);
codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&tune);
codegen::selection selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target.get());
codegen::transform::reassociate reassociate(&grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
// run passes
peephole.run(module);
dce.run(module);
tune.run(module);
grids.run(module);
reassociate.run(module);
peephole.run(module);
if(target->is_gpu()){

View File

@@ -22,6 +22,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc) {
/* prologue */
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + 0 ... TM;
@@ -88,8 +90,8 @@ class dot:
self.trans_b = trans_b
def __call__(self, a, b):
shape_a = tf.shape(a)
shape_b = tf.shape(b)
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M = shape_a[0]
K = shape_a[1]
N = shape_b[0]
@@ -98,9 +100,9 @@ class dot:
ldc = N
c = triton.empty([M, N])
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
lambda opt: [cdiv(M, opt.D('TM')), cdiv(N, opt.D('TN')), 1],
lambda opt: [cdiv(M, opt.d('TM')), cdiv(N, opt.d('TN'))],
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
TM = [128], TN = [128], TK = [32])
TM = [32, 64, 128], TN = [32, 64, 128], TK = [32])
dot_tn = dot()
@@ -119,7 +121,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb)
hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)

View File

@@ -200,8 +200,7 @@ struct function_record {
/// Special data structure which (temporarily) holds metadata about a bound class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
default_holder(true), module_local(false) { }
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
/// Handle to the parent scope
handle scope;
@@ -215,14 +214,11 @@ struct type_record {
/// How large is the underlying C++ type?
size_t type_size = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
void *(*operator_new)(size_t) = ::operator new;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
@@ -282,7 +278,7 @@ struct type_record {
}
};
inline function_call::function_call(const function_record &f, handle p) :
inline function_call::function_call(function_record &f, handle p) :
func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);

View File

@@ -17,7 +17,6 @@
#include <array>
#include <limits>
#include <tuple>
#include <type_traits>
#if defined(PYBIND11_CPP17)
# if defined(__has_include)
@@ -204,10 +203,10 @@ PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool t
}
struct value_and_holder {
instance *inst = nullptr;
size_t index = 0u;
const detail::type_info *type = nullptr;
void **vh = nullptr;
instance *inst;
size_t index;
const detail::type_info *type;
void **vh;
// Main constructor for a found value/holder:
value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) :
@@ -216,7 +215,7 @@ struct value_and_holder {
{}
// Default constructor (used to signal a value-and-holder not found by get_value_and_holder())
value_and_holder() {}
value_and_holder() : inst{nullptr} {}
// Used for past-the-end iterator
value_and_holder(size_t index) : index{index} {}
@@ -270,8 +269,8 @@ public:
struct iterator {
private:
instance *inst = nullptr;
const type_vec *types = nullptr;
instance *inst;
const type_vec *types;
value_and_holder curr;
friend struct values_and_holders;
iterator(instance *inst, const type_vec *tinfo)
@@ -571,17 +570,7 @@ public:
// Lazy allocation for unallocated values:
if (vptr == nullptr) {
auto *type = v_h.type ? v_h.type : typeinfo;
if (type->operator_new) {
vptr = type->operator_new(type->type_size);
} else {
#if defined(PYBIND11_CPP17)
if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__)
vptr = ::operator new(type->type_size,
(std::align_val_t) type->type_align);
else
#endif
vptr = ::operator new(type->type_size);
}
vptr = type->operator_new(type->type_size);
}
value = vptr;
}
@@ -785,47 +774,11 @@ template <typename T1, typename T2> struct is_copy_constructible<std::pair<T1, T
: all_of<is_copy_constructible<T1>, is_copy_constructible<T2>> {};
#endif
NAMESPACE_END(detail)
// polymorphic_type_hook<itype>::get(src, tinfo) determines whether the object pointed
// to by `src` actually is an instance of some class derived from `itype`.
// If so, it sets `tinfo` to point to the std::type_info representing that derived
// type, and returns a pointer to the start of the most-derived object of that type
// (in which `src` is a subobject; this will be the same address as `src` in most
// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src`
// and leaves `tinfo` at its default value of nullptr.
//
// The default polymorphic_type_hook just returns src. A specialization for polymorphic
// types determines the runtime type of the passed object and adjusts the this-pointer
// appropriately via dynamic_cast<void*>. This is what enables a C++ Animal* to appear
// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is
// registered with pybind11, and this Animal is in fact a Dog).
//
// You may specialize polymorphic_type_hook yourself for types that want to appear
// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern
// in performance-sensitive applications, used most notably in LLVM.)
template <typename itype, typename SFINAE = void>
struct polymorphic_type_hook
{
static const void *get(const itype *src, const std::type_info*&) { return src; }
};
template <typename itype>
struct polymorphic_type_hook<itype, detail::enable_if_t<std::is_polymorphic<itype>::value>>
{
static const void *get(const itype *src, const std::type_info*& type) {
type = src ? &typeid(*src) : nullptr;
return dynamic_cast<const void*>(src);
}
};
NAMESPACE_BEGIN(detail)
/// Generic type caster for objects stored on the heap
template <typename type> class type_caster_base : public type_caster_generic {
using itype = intrinsic_t<type>;
public:
static constexpr auto name = _<type>();
static PYBIND11_DESCR name() { return type_descr(_<type>()); }
type_caster_base() : type_caster_base(typeid(type)) { }
explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { }
@@ -840,28 +793,32 @@ public:
return cast(&src, return_value_policy::move, parent);
}
// Returns a (pointer, type_info) pair taking care of necessary type lookup for a
// polymorphic type (using RTTI by default, but can be overridden by specializing
// polymorphic_type_hook). If the instance isn't derived, returns the base version.
// Returns a (pointer, type_info) pair taking care of necessary RTTI type lookup for a
// polymorphic type. If the instance isn't derived, returns the non-RTTI base version.
template <typename T = itype, enable_if_t<std::is_polymorphic<T>::value, int> = 0>
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
const void *vsrc = src;
auto &cast_type = typeid(itype);
const std::type_info *instance_type = nullptr;
const void *vsrc = polymorphic_type_hook<itype>::get(src, instance_type);
if (instance_type && !same_type(cast_type, *instance_type)) {
// This is a base pointer to a derived type. If the derived type is registered
// with pybind11, we want to make the full derived object available.
// In the typical case where itype is polymorphic, we get the correct
// derived pointer (which may be != base pointer) by a dynamic_cast to
// most derived type. If itype is not polymorphic, we won't get here
// except via a user-provided specialization of polymorphic_type_hook,
// and the user has promised that no this-pointer adjustment is
// required in that case, so it's OK to use static_cast.
if (const auto *tpi = get_type_info(*instance_type))
return {vsrc, tpi};
if (vsrc) {
instance_type = &typeid(*src);
if (!same_type(cast_type, *instance_type)) {
// This is a base pointer to a derived type; if it is a pybind11-registered type, we
// can get the correct derived pointer (which may be != base pointer) by a
// dynamic_cast to most derived type:
if (auto *tpi = get_type_info(*instance_type))
return {dynamic_cast<const void *>(src), const_cast<const type_info *>(tpi)};
}
}
// Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so
// don't do a cast
return type_caster_generic::src_and_type(src, cast_type, instance_type);
return type_caster_generic::src_and_type(vsrc, cast_type, instance_type);
}
// Non-polymorphic type, so no dynamic casting; just call the generic version directly
template <typename T = itype, enable_if_t<!std::is_polymorphic<T>::value, int> = 0>
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
return type_caster_generic::src_and_type(src, typeid(itype));
}
static handle cast(const itype *src, return_value_policy policy, handle parent) {
@@ -878,7 +835,7 @@ public:
nullptr, nullptr, holder);
}
template <typename T> using cast_op_type = detail::cast_op_type<T>;
template <typename T> using cast_op_type = cast_op_type<T>;
operator itype*() { return (type *) value; }
operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); }
@@ -928,7 +885,7 @@ private:
"std::reference_wrapper<T> caster requires T to have a caster with an `T &` operator");
public:
bool load(handle src, bool convert) { return subcaster.load(src, convert); }
static constexpr auto name = caster_t::name;
static PYBIND11_DESCR name() { return caster_t::name(); }
static handle cast(const std::reference_wrapper<type> &src, return_value_policy policy, handle parent) {
// It is definitely wrong to take ownership of this pointer, so mask that rvp
if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic)
@@ -943,7 +900,7 @@ public:
protected: \
type value; \
public: \
static constexpr auto name = py_name; \
static PYBIND11_DESCR name() { return type_descr(py_name); } \
template <typename T_, enable_if_t<std::is_same<type, remove_cv_t<T_>>::value, int> = 0> \
static handle cast(T_ *src, return_value_policy policy, handle parent) { \
if (!src) return none().release(); \
@@ -1020,34 +977,20 @@ public:
return true;
}
template<typename U = T>
static typename std::enable_if<std::is_floating_point<U>::value, handle>::type
cast(U src, return_value_policy /* policy */, handle /* parent */) {
return PyFloat_FromDouble((double) src);
}
template<typename U = T>
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value && (sizeof(U) <= sizeof(long)), handle>::type
cast(U src, return_value_policy /* policy */, handle /* parent */) {
return PYBIND11_LONG_FROM_SIGNED((long) src);
}
template<typename U = T>
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type
cast(U src, return_value_policy /* policy */, handle /* parent */) {
return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src);
}
template<typename U = T>
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value && (sizeof(U) > sizeof(long)), handle>::type
cast(U src, return_value_policy /* policy */, handle /* parent */) {
return PyLong_FromLongLong((long long) src);
}
template<typename U = T>
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value && (sizeof(U) > sizeof(unsigned long)), handle>::type
cast(U src, return_value_policy /* policy */, handle /* parent */) {
return PyLong_FromUnsignedLongLong((unsigned long long) src);
static handle cast(T src, return_value_policy /* policy */, handle /* parent */) {
if (std::is_floating_point<T>::value) {
return PyFloat_FromDouble((double) src);
} else if (sizeof(T) <= sizeof(long)) {
if (std::is_signed<T>::value)
return PyLong_FromLong((long) src);
else
return PyLong_FromUnsignedLong((unsigned long) src);
} else {
if (std::is_signed<T>::value)
return PyLong_FromLongLong((long long) src);
else
return PyLong_FromUnsignedLongLong((unsigned long long) src);
}
}
PYBIND11_TYPE_CASTER(T, _<std::is_integral<T>::value>("int", "float"));
@@ -1106,7 +1049,7 @@ public:
template <typename T> using cast_op_type = void*&;
operator void *&() { return value; }
static constexpr auto name = _("capsule");
static PYBIND11_DESCR name() { return type_descr(_("capsule")); }
private:
void *value = nullptr;
};
@@ -1349,7 +1292,7 @@ public:
return one_char;
}
static constexpr auto name = _(PYBIND11_STRING_NAME);
static PYBIND11_DESCR name() { return type_descr(_(PYBIND11_STRING_NAME)); }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
@@ -1374,7 +1317,9 @@ public:
return cast_impl(std::forward<T>(src), policy, parent, indices{});
}
static constexpr auto name = _("Tuple[") + concat(make_caster<Ts>::name...) + _("]");
static PYBIND11_DESCR name() {
return type_descr(_("Tuple[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
}
template <typename T> using cast_op_type = type;
@@ -1519,7 +1464,7 @@ struct move_only_holder_caster {
auto *ptr = holder_helper<holder_type>::get(src);
return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
}
static constexpr auto name = type_caster_base<type>::name;
static PYBIND11_DESCR name() { return type_caster_base<type>::name(); }
};
template <typename type, typename deleter>
@@ -1550,10 +1495,10 @@ template <typename base, typename holder> struct is_holder_type :
template <typename base, typename deleter> struct is_holder_type<base, std::unique_ptr<base, deleter>> :
std::true_type {};
template <typename T> struct handle_type_name { static constexpr auto name = _<T>(); };
template <> struct handle_type_name<bytes> { static constexpr auto name = _(PYBIND11_BYTES_NAME); };
template <> struct handle_type_name<args> { static constexpr auto name = _("*args"); };
template <> struct handle_type_name<kwargs> { static constexpr auto name = _("**kwargs"); };
template <typename T> struct handle_type_name { static PYBIND11_DESCR name() { return _<T>(); } };
template <> struct handle_type_name<bytes> { static PYBIND11_DESCR name() { return _(PYBIND11_BYTES_NAME); } };
template <> struct handle_type_name<args> { static PYBIND11_DESCR name() { return _("*args"); } };
template <> struct handle_type_name<kwargs> { static PYBIND11_DESCR name() { return _("**kwargs"); } };
template <typename type>
struct pyobject_caster {
@@ -1571,7 +1516,7 @@ struct pyobject_caster {
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
return src.inc_ref();
}
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
template <typename T>
@@ -1611,8 +1556,7 @@ template <typename T> using move_never = none_of<move_always<T>, move_if_unrefer
// everything else returns a reference/pointer to a local variable.
template <typename type> using cast_is_temporary_value_reference = bool_constant<
(std::is_reference<type>::value || std::is_pointer<type>::value) &&
!std::is_base_of<type_caster_generic, make_caster<type>>::value &&
!std::is_same<intrinsic_t<type>, void>::value
!std::is_base_of<type_caster_generic, make_caster<type>>::value
>;
// When a value returned from a C++ function is being cast back to Python, we almost always want to
@@ -1625,9 +1569,8 @@ template <typename Return, typename SFINAE = void> struct return_value_policy_ov
template <typename Return> struct return_value_policy_override<Return,
detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
static return_value_policy policy(return_value_policy p) {
return !std::is_lvalue_reference<Return>::value &&
!std::is_pointer<Return>::value
? return_value_policy::move : p;
return !std::is_lvalue_reference<Return>::value && !std::is_pointer<Return>::value
? return_value_policy::move : p;
}
};
@@ -1855,7 +1798,7 @@ struct function_record;
/// Internal data associated with a single function call
struct function_call {
function_call(const function_record &f, handle p); // Implementation in attr.h
function_call(function_record &f, handle p); // Implementation in attr.h
/// The function data:
const function_record &func;
@@ -1897,7 +1840,7 @@ public:
static constexpr bool has_kwargs = kwargs_pos < 0;
static constexpr bool has_args = args_pos < 0;
static constexpr auto arg_names = concat(type_descr(make_caster<Args>::name)...);
static PYBIND11_DESCR arg_names() { return detail::concat(make_caster<Args>::name()...); }
bool load_args(function_call &call) {
return load_impl_sequence(call, indices{});
@@ -2116,13 +2059,9 @@ object object_api<Derived>::call(Args &&...args) const {
NAMESPACE_END(detail)
#define PYBIND11_MAKE_OPAQUE(...) \
#define PYBIND11_MAKE_OPAQUE(Type) \
namespace pybind11 { namespace detail { \
template<> class type_caster<__VA_ARGS__> : public type_caster_base<__VA_ARGS__> { }; \
template<> class type_caster<Type> : public type_caster_base<Type> { }; \
}}
/// Lets you pass a type containing a `,` through a macro parameter without needing a separate
/// typedef, e.g.: `PYBIND11_OVERLOAD(PYBIND11_TYPE(ReturnType<A, B>), PYBIND11_TYPE(Parent<C, D>), f, arg)`
#define PYBIND11_TYPE(...) __VA_ARGS__
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -25,13 +25,9 @@ template <typename T> struct format_descriptor<std::complex<T>, detail::enable_i
static std::string format() { return std::string(value); }
};
#ifndef PYBIND11_CPP17
template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
NAMESPACE_BEGIN(detail)
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {

View File

@@ -10,7 +10,6 @@
#pragma once
#include "../attr.h"
#include "../options.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
@@ -290,9 +289,13 @@ extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject
inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto instance = reinterpret_cast<detail::instance *>(nurse);
auto &current_patients = internals.patients[nurse];
instance->has_patients = true;
for (auto &p : current_patients)
if (p == patient)
return;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);
current_patients.push_back(patient);
}
inline void clear_patients(PyObject *self) {
@@ -469,7 +472,7 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
if (tinfo && tinfo->get_buffer)
break;
}
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
if (view)
view->obj = nullptr;
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");

View File

@@ -93,8 +93,8 @@
#endif
#define PYBIND11_VERSION_MAJOR 2
#define PYBIND11_VERSION_MINOR 3
#define PYBIND11_VERSION_PATCH 0
#define PYBIND11_VERSION_MINOR 2
#define PYBIND11_VERSION_PATCH 4
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
@@ -159,8 +159,6 @@
#define PYBIND11_BYTES_SIZE PyBytes_Size
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
#define PYBIND11_BYTES_NAME "bytes"
#define PYBIND11_STRING_NAME "str"
#define PYBIND11_SLICE_OBJECT PyObject
@@ -183,8 +181,6 @@
#define PYBIND11_BYTES_SIZE PyString_Size
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
#define PYBIND11_BYTES_NAME "str"
#define PYBIND11_STRING_NAME "unicode"
#define PYBIND11_SLICE_OBJECT PySliceObject
@@ -212,31 +208,6 @@ extern "C" {
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
#define PYBIND11_CONCAT(first, second) first##second
#define PYBIND11_CHECK_PYTHON_VERSION \
{ \
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
const char *runtime_ver = Py_GetVersion(); \
size_t len = std::strlen(compiled_ver); \
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for Python %s, " \
"but the interpreter version is incompatible: %s.", \
compiled_ver, runtime_ver); \
return nullptr; \
} \
}
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
/** \rst
***Deprecated in favor of PYBIND11_MODULE***
@@ -256,10 +227,27 @@ extern "C" {
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
static PyObject *pybind11_init(); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
int major, minor; \
if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
return nullptr; \
} else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for " \
"version %i.%i, while the interpreter is running " \
"version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
major, minor); \
return nullptr; \
} \
try { \
return pybind11_init(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PyObject *pybind11_init()
@@ -283,12 +271,29 @@ extern "C" {
#define PYBIND11_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
int major, minor; \
if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \
PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \
return nullptr; \
} else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for " \
"version %i.%i, while the interpreter is running " \
"version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \
major, minor); \
return nullptr; \
} \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
@@ -386,7 +391,7 @@ struct instance {
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
nonsimple_values_and_holders nonsimple;
};
/// Weak references
/// Weak references (needed for keep alive):
PyObject *weakrefs;
/// If true, the pointer is owned which means we're free to manage it with a holder.
bool owned : 1;
@@ -403,10 +408,10 @@ struct instance {
* (which is typically the size of two pointers), or when multiple inheritance is used on the
* python side. Non-simple layout allocates the required amount of memory to have multiple
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
* pointer to allocated space of the required space to hold a sequence of value pointers and
* pointer to allocated space of the required space to hold a a sequence of value pointers and
* holders followed `status`, a set of bit flags (1 byte each), i.e.
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
* `sizeof(void *)`. `nonsimple.holder_constructed` is, for convenience, a pointer to the
* beginning of the [bb...] block (but not independently allocated).
*
* Status bits indicate whether the associated holder is constructed (&
@@ -579,11 +584,6 @@ template <typename T, typename... Us> using deferred_t = typename deferred_type<
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
/// can be converted to a Base pointer)
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
template <template<typename...> class Base>
struct is_template_base_of_impl {
template <typename... Us> static std::true_type check(Base<Us...> *);
@@ -702,13 +702,9 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_ar
static std::string format() { return std::string(1, c); }
};
#if !defined(PYBIND11_CPP17)
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
#endif
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;

View File

@@ -1,5 +1,6 @@
/*
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
pybind11/detail/descr.h: Helper type for concatenating type signatures
either at runtime (C++11) or compile time (C++14)
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
@@ -14,87 +15,171 @@
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if !defined(_MSC_VER)
# define PYBIND11_DESCR_CONSTEXPR static constexpr
#else
# define PYBIND11_DESCR_CONSTEXPR const
#endif
/* Concatenate type signatures at compile time using C++14 */
#if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
#define PYBIND11_CONSTEXPR_DESCR
/* Concatenate type signatures at compile time */
template <size_t N, typename... Ts>
struct descr {
char text[N + 1];
template <size_t Size1, size_t Size2> class descr {
template <size_t Size1_, size_t Size2_> friend class descr;
public:
constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
: descr(text, types,
make_index_sequence<Size1>(),
make_index_sequence<Size2>()) { }
constexpr descr() : text{'\0'} { }
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
constexpr const char *text() const { return m_text; }
constexpr const std::type_info * const * types() const { return m_types; }
template <size_t... Is>
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
template <typename... Chars>
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
return {{&typeid(Ts)..., nullptr}};
template <size_t OtherSize1, size_t OtherSize2>
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
return concat(other,
make_index_sequence<Size1>(),
make_index_sequence<Size2>(),
make_index_sequence<OtherSize1>(),
make_index_sequence<OtherSize2>());
}
protected:
template <size_t... Indices1, size_t... Indices2>
constexpr descr(
char const (&text) [Size1+1],
const std::type_info * const (&types) [Size2+1],
index_sequence<Indices1...>, index_sequence<Indices2...>)
: m_text{text[Indices1]..., '\0'},
m_types{types[Indices2]..., nullptr } {}
template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
concat(const descr<OtherSize1, OtherSize2> &other,
index_sequence<Indices1...>, index_sequence<Indices2...>,
index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
{ m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
{ m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
);
}
protected:
char m_text[Size1 + 1];
const std::type_info * m_types[Size2 + 1];
};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
template <size_t Size> constexpr descr<Size - 1, 0> _(char const(&text)[Size]) {
return descr<Size - 1, 0>(text, { nullptr });
}
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
}
template <size_t N>
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
constexpr descr<0> _(char const(&)[1]) { return {}; }
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
template <size_t...Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
};
// Ternary description (like std::conditional)
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
template <bool B, size_t Size1, size_t Size2>
constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
return _(text1);
}
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
template <bool B, size_t Size1, size_t Size2>
constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
return _(text2);
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
template <bool B, typename T1, typename T2>
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
}
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
constexpr descr<0> concat() { return {}; }
template <size_t N, typename... Ts>
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
template <size_t N, typename... Ts, typename... Args>
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
return d + _(", ") + concat(args...);
template <typename Type> constexpr descr<1, 1> _() {
return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
}
template <size_t N, typename... Ts>
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return _("{") + descr + _("}");
inline constexpr descr<0, 0> concat() { return _(""); }
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr) { return descr; }
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr, Args&&... args) { return descr + _(", ") + concat(args...); }
template <size_t Size1, size_t Size2> auto constexpr type_descr(descr<Size1, Size2> descr) { return _("{") + descr + _("}"); }
#define PYBIND11_DESCR constexpr auto
#else /* Simpler C++11 implementation based on run-time memory allocation and copying */
class descr {
public:
PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
size_t nChars = len(text), nTypes = len(types);
m_text = new char[nChars];
m_types = new const std::type_info *[nTypes];
memcpy(m_text, text, nChars * sizeof(char));
memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
}
PYBIND11_NOINLINE descr operator+(descr &&d2) && {
descr r;
size_t nChars1 = len(m_text), nTypes1 = len(m_types);
size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
r.m_text = new char[nChars1 + nChars2 - 1];
r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
delete[] m_text; delete[] m_types;
delete[] d2.m_text; delete[] d2.m_types;
return r;
}
char *text() { return m_text; }
const std::type_info * * types() { return m_types; }
protected:
PYBIND11_NOINLINE descr() { }
template <typename T> static size_t len(const T *ptr) { // return length including null termination
const T *it = ptr;
while (*it++ != (T) 0)
;
return static_cast<size_t>(it - ptr);
}
const std::type_info **m_types = nullptr;
char *m_text = nullptr;
};
/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
PYBIND11_NOINLINE inline descr _(const char *text) {
const std::type_info *types[1] = { nullptr };
return descr(text, types);
}
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
template <typename Type> PYBIND11_NOINLINE descr _() {
const std::type_info *types[2] = { &typeid(Type), nullptr };
return descr("%", types);
}
template <size_t Size> PYBIND11_NOINLINE descr _() {
const std::type_info *types[1] = { nullptr };
return descr(std::to_string(Size).c_str(), types);
}
PYBIND11_NOINLINE inline descr concat() { return _(""); }
PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
template <typename... Args> PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward<Args>(args)...); }
PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
#define PYBIND11_DESCR ::pybind11::detail::descr
#endif
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -24,7 +24,7 @@ public:
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
static PYBIND11_DESCR name() { return type_descr(_<value_and_holder>()); }
private:
value_and_holder *value = nullptr;

View File

@@ -23,7 +23,7 @@ inline PyObject *make_object_base_type(PyTypeObject *metaclass);
#if PY_VERSION_HEX >= 0x03070000
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (tstate))
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
// Usually an int but a long on Cygwin64 with Python 3.x
@@ -116,7 +116,7 @@ struct internals {
struct type_info {
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
size_t type_size, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
@@ -138,13 +138,7 @@ struct type_info {
};
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 3
#if defined(_DEBUG)
# define PYBIND11_BUILD_TYPE "_debug"
#else
# define PYBIND11_BUILD_TYPE ""
#endif
#define PYBIND11_INTERNALS_VERSION 2
#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
@@ -153,10 +147,10 @@ struct type_info {
#endif
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__"
/// Each module locally stores a pointer to the `internals` data. The data
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.

View File

@@ -16,8 +16,6 @@
#include <cxxabi.h>
#endif
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring

View File

@@ -17,11 +17,6 @@
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
# ifdef __clang__
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
// under Clang, so disable that warning here:
# pragma GCC diagnostic ignored "-Wdeprecated"
# endif
# if __GNUC__ >= 7
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
# endif
@@ -186,26 +181,28 @@ template <typename Type_> struct EigenProps {
}
}
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
static constexpr bool show_c_contiguous = show_order && requires_row_major;
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
static PYBIND11_DESCR descriptor() {
constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
constexpr bool show_order = is_eigen_dense_map<Type>::value;
constexpr bool show_c_contiguous = show_order && requires_row_major;
constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
static constexpr auto descriptor =
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
_("]") +
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
// *gave* a numpy.ndarray of the right type and dimensions.
_<show_writeable>(", flags.writeable", "") +
_<show_c_contiguous>(", flags.c_contiguous", "") +
_<show_f_contiguous>(", flags.f_contiguous", "") +
_("]");
return type_descr(_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
_("]") +
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
// *gave* a numpy.ndarray of the right type and dimensions.
_<show_writeable>(", flags.writeable", "") +
_<show_c_contiguous>(", flags.c_contiguous", "") +
_<show_f_contiguous>(", flags.f_contiguous", "") +
_("]")
);
}
};
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
@@ -342,7 +339,7 @@ public:
return cast_impl(src, policy, parent);
}
static constexpr auto name = props::descriptor;
static PYBIND11_DESCR name() { return props::descriptor(); }
operator Type*() { return &value; }
operator Type&() { return value; }
@@ -382,7 +379,7 @@ public:
}
}
static constexpr auto name = props::descriptor;
static PYBIND11_DESCR name() { return props::descriptor(); }
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
@@ -527,7 +524,7 @@ public:
}
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
static constexpr auto name = props::descriptor;
static PYBIND11_DESCR name() { return props::descriptor(); }
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
@@ -594,7 +591,7 @@ struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
}
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
+ npy_format_descriptor<Scalar>::name + _("]"));
+ npy_format_descriptor<Scalar>::name() + _("]"));
};
NAMESPACE_END(detail)

View File

@@ -90,14 +90,8 @@ NAMESPACE_END(detail)
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
optional parameter can be used to skip the registration of signal handlers (see the
`Python documentation`_ for details). Calling this function again after the interpreter
Python documentation for details). Calling this function again after the interpreter
has already been initialized is a fatal error.
If initializing the Python interpreter fails, then the program is terminated. (This
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
of throwing exceptions on errors.)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
if (Py_IsInitialized())

View File

@@ -54,20 +54,9 @@ public:
}
}
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function&& f_) : f(std::move(f_)) {}
func_handle(const func_handle&) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
value = [func](Args... args) -> Return {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
object retval(func(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
};
@@ -86,8 +75,10 @@ public:
return cpp_function(std::forward<Func>(f_), policy).release();
}
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
+ make_caster<retval_type>::name + _("]"));
PYBIND11_TYPE_CASTER(type, _("Callable[[") +
argument_loader<Args...>::arg_names() + _("], ") +
make_caster<retval_type>::name() +
_("]"));
};
NAMESPACE_END(detail)

View File

@@ -25,8 +25,7 @@ class pythonbuf : public std::streambuf {
private:
using traits_type = std::streambuf::traits_type;
const size_t buf_size;
std::unique_ptr<char[]> d_buffer;
char d_buffer[1024];
object pywrite;
object pyflush;
@@ -43,11 +42,8 @@ private:
// This subtraction cannot be negative, so dropping the sign
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
{
gil_scoped_acquire tmp;
pywrite(line);
pyflush();
}
pywrite(line);
pyflush();
setp(pbase(), epptr());
}
@@ -55,13 +51,10 @@ private:
}
public:
pythonbuf(object pyostream, size_t buffer_size = 1024)
: buf_size(buffer_size),
d_buffer(new char[buf_size]),
pywrite(pyostream.attr("write")),
pythonbuf(object pyostream)
: pywrite(pyostream.attr("write")),
pyflush(pyostream.attr("flush")) {
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
setp(d_buffer, d_buffer + sizeof(d_buffer) - 1);
}
/// Sync before destroy
@@ -201,7 +194,7 @@ inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::strin
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
.def("__enter__", &detail::OstreamRedirect::enter)
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
.def("__exit__", [](detail::OstreamRedirect &self, args) { self.exit(); });
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -18,9 +18,9 @@
#include <cstring>
#include <sstream>
#include <string>
#include <initializer_list>
#include <functional>
#include <utility>
#include <vector>
#include <typeindex>
#if defined(_MSC_VER)
@@ -250,7 +250,7 @@ template <typename T> struct array_info_scalar {
typedef T type;
static constexpr bool is_array = false;
static constexpr bool is_empty = false;
static constexpr auto extents = _("");
static PYBIND11_DESCR extents() { return _(""); }
static void append_extents(list& /* shape */) { }
};
// Computes underlying type and a comma-separated list of extents for array
@@ -269,9 +269,15 @@ template <typename T, size_t N> struct array_info<std::array<T, N>> {
array_info<T>::append_extents(shape);
}
static constexpr auto extents = _<array_info<T>::is_array>(
concat(_<N>(), array_info<T>::extents), _<N>()
);
template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
static PYBIND11_DESCR extents() {
return _<N>();
}
template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
static PYBIND11_DESCR extents() {
return concat(_<N>(), array_info<T>::extents());
}
};
// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
@@ -440,7 +446,7 @@ public:
/// This is essentially the same as calling numpy.dtype(args) in Python.
static dtype from_args(object args) {
PyObject *ptr = nullptr;
if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr)
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
throw error_already_set();
return reinterpret_steal<dtype>(ptr);
}
@@ -855,14 +861,14 @@ public:
// Reference to element at a given index
template<typename... Ix> const T& at(Ix... index) const {
if ((ssize_t) sizeof...(index) != ndim())
if (sizeof...(index) != ndim())
fail_dim_check(sizeof...(index), "index dimension mismatch");
return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
}
// Mutable reference to element at a given index
template<typename... Ix> T& mutable_at(Ix... index) {
if ((ssize_t) sizeof...(index) != ndim())
if (sizeof...(index) != ndim())
fail_dim_check(sizeof...(index), "index dimension mismatch");
return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
}
@@ -942,8 +948,8 @@ template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
static std::string format() {
using namespace detail;
static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
PYBIND11_DESCR extents = _("(") + array_info<T>::extents() + _(")");
return extents.text() + format_descriptor<remove_all_extents_t<T>>::format();
}
};
@@ -962,7 +968,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
return src.inc_ref();
}
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
template <typename T>
@@ -972,34 +978,7 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
}
};
template <typename T, typename = void>
struct npy_format_descriptor_name;
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
static constexpr auto name = _<std::is_same<T, bool>::value>(
_("bool"), _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()
);
};
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
_("float") + _<sizeof(T)*8>(), _("longdouble")
);
};
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
|| std::is_same<typename T::value_type, double>::value>(
_("complex") + _<sizeof(typename T::value_type)*16>(), _("longcomplex")
);
};
template <typename T>
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
: npy_format_descriptor_name<T> {
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
// NB: the order here must match the one in common.h
constexpr static const int values[15] = {
@@ -1018,10 +997,25 @@ public:
return reinterpret_borrow<pybind11::dtype>(ptr);
pybind11_fail("Unsupported buffer format!");
}
template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
static PYBIND11_DESCR name() {
return _<std::is_same<T, bool>::value>(_("bool"),
_<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
}
template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
static PYBIND11_DESCR name() {
return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
_("float") + _<sizeof(T)*8>(), _("longdouble"));
}
template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
static PYBIND11_DESCR name() {
return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
_("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
}
};
#define PYBIND11_DECL_CHAR_FMT \
static constexpr auto name = _("S") + _<N>(); \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
@@ -1033,7 +1027,7 @@ private:
public:
static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;
static PYBIND11_DESCR name() { return _("(") + array_info<T>::extents() + _(")") + base_descr::name(); }
static pybind11::dtype dtype() {
list shape;
array_info<T>::append_extents(shape);
@@ -1045,7 +1039,7 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
private:
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
static constexpr auto name = base_descr::name;
static PYBIND11_DESCR name() { return base_descr::name(); }
static pybind11::dtype dtype() { return base_descr::dtype(); }
};
@@ -1058,7 +1052,7 @@ struct field_descriptor {
};
inline PYBIND11_NOINLINE void register_structured_dtype(
any_container<field_descriptor> fields,
const std::initializer_list<field_descriptor>& fields,
const std::type_info& tinfo, ssize_t itemsize,
bool (*direct_converter)(PyObject *, void *&)) {
@@ -1067,7 +1061,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
pybind11_fail("NumPy: dtype is already registered");
list names, formats, offsets;
for (auto field : *fields) {
for (auto field : fields) {
if (!field.descr)
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
field.name + "` @ " + tinfo.name());
@@ -1084,7 +1078,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(std::move(fields));
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
ssize_t offset = 0;
@@ -1120,7 +1114,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
template <typename T, typename SFINAE> struct npy_format_descriptor {
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
static constexpr auto name = make_caster<T>::name;
static PYBIND11_DESCR name() { return make_caster<T>::name(); }
static pybind11::dtype dtype() {
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
@@ -1131,8 +1125,8 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
return format_str;
}
static void register_dtype(any_container<field_descriptor> fields) {
register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
sizeof(T), &direct_converter);
}
@@ -1205,8 +1199,7 @@ private:
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
(::std::vector<::pybind11::detail::field_descriptor> \
{PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
@@ -1227,8 +1220,7 @@ private:
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
(::std::vector<::pybind11::detail::field_descriptor> \
{PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
#endif // __CLION_IDE__
@@ -1466,10 +1458,7 @@ public:
private:
remove_reference_t<Func> f;
// Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
// when arg_call_types is manually inlined.
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
template <size_t Index> using param_n_t = typename pack_element<Index, typename vectorize_arg<Args>::call_type...>::type;
// Runs a vectorized function given arguments tuple and three index sequences:
// - Index is the full set of 0 ... (N-1) argument indices;
@@ -1509,7 +1498,7 @@ private:
if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
else result = array_t<Return>(shape);
if (size == 0) return std::move(result);
if (size == 0) return result;
/* Call the function */
if (trivial == broadcast_trivial::non_trivial)
@@ -1517,7 +1506,7 @@ private:
else
apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
return std::move(result);
return result;
}
template <size_t... Index, size_t... VIndex, size_t... BIndex>
@@ -1570,7 +1559,9 @@ vectorize_extractor(const Func &f, Return (*) (Args ...)) {
}
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor<T>::name + _("]");
static PYBIND11_DESCR name() {
return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
}
};
NAMESPACE_END(detail)

View File

@@ -10,17 +10,7 @@
#pragma once
#if defined(__INTEL_COMPILER)
# pragma warning push
# pragma warning disable 68 // integer conversion resulted in a change of sign
# pragma warning disable 186 // pointless comparison of unsigned integer with zero
# pragma warning disable 878 // incompatible exception specifications
# pragma warning disable 1334 // the "template" keyword used for syntactic disambiguation may only be used within a template
# pragma warning disable 1682 // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
# pragma warning disable 1786 // function "strdup" was declared deprecated
# pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard
# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline"
#elif defined(_MSC_VER)
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
@@ -29,6 +19,15 @@
# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name
# pragma warning(disable: 4702) // warning C4702: unreachable code
# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified
#elif defined(__INTEL_COMPILER)
# pragma warning(push)
# pragma warning(disable: 68) // integer conversion resulted in a change of sign
# pragma warning(disable: 186) // pointless comparison of unsigned integer with zero
# pragma warning(disable: 878) // incompatible exception specifications
# pragma warning(disable: 1334) // the "template" keyword used for syntactic disambiguation may only be used within a template
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
# pragma warning(disable: 1875) // offsetof applied to non-POD (Plain Old Data) types is nonstandard
# pragma warning(disable: 2196) // warning #2196: routine is both "inline" and "noinline"
#elif defined(__GNUG__) && !defined(__clang__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wunused-but-set-parameter"
@@ -41,11 +40,6 @@
# endif
#endif
#if defined(__GNUG__) && !defined(__clang__)
#include <cxxabi.h>
#endif
#include "attr.h"
#include "options.h"
#include "detail/class.h"
@@ -57,7 +51,6 @@ NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
class cpp_function : public function {
public:
cpp_function() { }
cpp_function(std::nullptr_t) { }
/// Construct a cpp_function from a vanilla function pointer
template <typename Return, typename... Args, typename... Extra>
@@ -100,6 +93,7 @@ protected:
template <typename Func, typename Return, typename... Args, typename... Extra>
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
using namespace detail;
struct capture { remove_reference_t<Func> f; };
/* Store the function including any extra state it might have (e.g. a lambda capture object) */
@@ -170,11 +164,10 @@ protected:
process_attributes<Extra...>::init(extra..., rec);
/* Generate a readable signature describing the function's arguments and return value types */
static constexpr auto signature = _("(") + cast_in::arg_names + _(") -> ") + cast_out::name;
PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types();
PYBIND11_DESCR signature = _("(") + cast_in::arg_names() + _(") -> ") + cast_out::name();
/* Register the function with Python from generic (non-templated) code */
initialize_generic(rec, signature.text, types.data(), sizeof...(Args));
initialize_generic(rec, signature.text(), signature.types(), sizeof...(Args));
if (cast_in::has_args) rec->has_args = true;
if (cast_in::has_kwargs) rec->has_kwargs = true;
@@ -224,30 +217,34 @@ protected:
/* Generate a proper function signature */
std::string signature;
size_t type_index = 0, arg_index = 0;
for (auto *pc = text; *pc != '\0'; ++pc) {
const auto c = *pc;
size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0;
while (true) {
char c = text[char_index++];
if (c == '\0')
break;
if (c == '{') {
// Write arg name for everything except *args and **kwargs.
if (*(pc + 1) == '*')
continue;
if (arg_index < rec->args.size() && rec->args[arg_index].name) {
signature += rec->args[arg_index].name;
} else if (arg_index == 0 && rec->is_method) {
signature += "self";
} else {
signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0));
// Write arg name for everything except *args, **kwargs and return type.
if (type_depth == 0 && text[char_index] != '*' && arg_index < args) {
if (!rec->args.empty() && rec->args[arg_index].name) {
signature += rec->args[arg_index].name;
} else if (arg_index == 0 && rec->is_method) {
signature += "self";
} else {
signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0));
}
signature += ": ";
}
signature += ": ";
++type_depth;
} else if (c == '}') {
// Write default value if available.
if (arg_index < rec->args.size() && rec->args[arg_index].descr) {
signature += " = ";
signature += rec->args[arg_index].descr;
--type_depth;
if (type_depth == 0) {
if (arg_index < rec->args.size() && rec->args[arg_index].descr) {
signature += "=";
signature += rec->args[arg_index].descr;
}
arg_index++;
}
arg_index++;
} else if (c == '%') {
const std::type_info *t = types[type_index++];
if (!t)
@@ -272,9 +269,14 @@ protected:
signature += c;
}
}
if (arg_index != args || types[type_index] != nullptr)
if (type_depth != 0 || types[type_index] != nullptr)
pybind11_fail("Internal error while parsing type signature (2)");
#if !defined(PYBIND11_CONSTEXPR_DESCR)
delete[] types;
delete[] text;
#endif
#if PY_MAJOR_VERSION < 3
if (strcmp(rec->name, "__next__") == 0) {
std::free(rec->name);
@@ -426,8 +428,8 @@ protected:
using namespace detail;
/* Iterator over the list of potentially admissible overloads */
const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
*it = overloads;
function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
*it = overloads;
/* Need to know how many arguments + keyword arguments there are to pick the right overload */
const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in);
@@ -483,7 +485,7 @@ protected:
result other than PYBIND11_TRY_NEXT_OVERLOAD.
*/
const function_record &func = *it;
function_record &func = *it;
size_t pos_args = func.nargs; // Number of positional arguments that we need
if (func.has_args) --pos_args; // (but don't count py::args
if (func.has_kwargs) --pos_args; // or py::kwargs)
@@ -515,7 +517,7 @@ protected:
// 1. Copy any position arguments given.
bool bad_arg = false;
for (; args_copied < args_to_copy; ++args_copied) {
const argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr;
argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr;
if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) {
bad_arg = true;
break;
@@ -656,22 +658,13 @@ protected:
result = PYBIND11_TRY_NEXT_OVERLOAD;
}
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) {
// The error reporting logic below expects 'it' to be valid, as it would be
// if we'd encountered this failure in the first-pass loop.
if (!result)
it = &call.func;
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
break;
}
}
}
} catch (error_already_set &e) {
e.restore();
return nullptr;
#if defined(__GNUG__) && !defined(__clang__)
} catch ( abi::__forced_unwind& ) {
throw;
#endif
} catch (...) {
/* When an exception is caught, give each registered exception
translator a chance to translate it to a Python exception
@@ -718,7 +711,7 @@ protected:
" arguments. The following argument types are supported:\n";
int ctr = 0;
for (const function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
for (function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
msg += " "+ std::to_string(++ctr) + ". ";
bool wrote_sig = false;
@@ -906,7 +899,6 @@ protected:
tinfo->type = (PyTypeObject *) m_ptr;
tinfo->cpptype = rec.type;
tinfo->type_size = rec.type_size;
tinfo->type_align = rec.type_align;
tinfo->operator_new = rec.operator_new;
tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size);
tinfo->init_instance = rec.init_instance;
@@ -969,18 +961,18 @@ protected:
tinfo->get_buffer_data = get_buffer_data;
}
// rec_func must be set for either fget or fset.
void def_property_static_impl(const char *name,
handle fget, handle fset,
detail::function_record *rec_func) {
const auto is_static = rec_func && !(rec_func->is_method && rec_func->scope);
const auto has_doc = rec_func && rec_func->doc && pybind11::options::show_user_defined_docstrings();
detail::function_record *rec_fget) {
const auto is_static = !(rec_fget->is_method && rec_fget->scope);
const auto has_doc = rec_fget->doc && pybind11::options::show_user_defined_docstrings();
auto property = handle((PyObject *) (is_static ? get_internals().static_property_type
: &PyProperty_Type));
attr(name) = property(fget.ptr() ? fget : none(),
fset.ptr() ? fset : none(),
/*deleter*/none(),
pybind11::str(has_doc ? rec_func->doc : ""));
pybind11::str(has_doc ? rec_fget->doc : ""));
}
};
@@ -998,21 +990,11 @@ template <typename T> struct has_operator_delete_size<T, void_t<decltype(static_
: std::true_type { };
/// Call class-specific delete if it exists or global otherwise. Can also be an overload set.
template <typename T, enable_if_t<has_operator_delete<T>::value, int> = 0>
void call_operator_delete(T *p, size_t, size_t) { T::operator delete(p); }
void call_operator_delete(T *p, size_t) { T::operator delete(p); }
template <typename T, enable_if_t<!has_operator_delete<T>::value && has_operator_delete_size<T>::value, int> = 0>
void call_operator_delete(T *p, size_t s, size_t) { T::operator delete(p, s); }
void call_operator_delete(T *p, size_t s) { T::operator delete(p, s); }
inline void call_operator_delete(void *p, size_t s, size_t a) {
(void)s; (void)a;
#if defined(PYBIND11_CPP17)
if (a > __STDCPP_DEFAULT_NEW_ALIGNMENT__)
::operator delete(p, s, std::align_val_t(a));
else
::operator delete(p, s);
#else
::operator delete(p);
#endif
}
inline void call_operator_delete(void *p, size_t) { ::operator delete(p); }
NAMESPACE_END(detail)
@@ -1022,18 +1004,10 @@ template <typename /*Derived*/, typename F>
auto method_adaptor(F &&f) -> decltype(std::forward<F>(f)) { return std::forward<F>(f); }
template <typename Derived, typename Return, typename Class, typename... Args>
auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) {
static_assert(detail::is_accessible_base_of<Class, Derived>::value,
"Cannot bind an inaccessible base class method; use a lambda definition instead");
return pmf;
}
auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { return pmf; }
template <typename Derived, typename Return, typename Class, typename... Args>
auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const {
static_assert(detail::is_accessible_base_of<Class, Derived>::value,
"Cannot bind an inaccessible base class method; use a lambda definition instead");
return pmf;
}
auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { return pmf; }
template <typename type_, typename... options>
class class_ : public detail::generic_type {
@@ -1075,11 +1049,10 @@ public:
record.name = name;
record.type = &typeid(type);
record.type_size = sizeof(conditional_t<has_alias, type_alias, type>);
record.type_align = alignof(conditional_t<has_alias, type_alias, type>&);
record.holder_size = sizeof(holder_type);
record.init_instance = init_instance;
record.dealloc = dealloc;
record.default_holder = detail::is_instantiation<std::unique_ptr, holder_type>::value;
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
set_operator_new<type>(&record);
@@ -1121,7 +1094,7 @@ public:
"def_static(...) called with a non-static member function pointer");
cpp_function cf(std::forward<Func>(f), name(name_), scope(*this),
sibling(getattr(*this, name_, none())), extra...);
attr(cf.name()) = staticmethod(cf);
attr(cf.name()) = cf;
return *this;
}
@@ -1185,7 +1158,7 @@ public:
template <typename C, typename D, typename... Extra>
class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) {
static_assert(std::is_same<C, type>::value || std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
static_assert(std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)),
fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this));
def_property(name, fget, fset, return_value_policy::reference_internal, extra...);
@@ -1194,7 +1167,7 @@ public:
template <typename C, typename D, typename... Extra>
class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) {
static_assert(std::is_same<C, type>::value || std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
static_assert(std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this));
def_property_readonly(name, fget, return_value_policy::reference_internal, extra...);
return *this;
@@ -1225,7 +1198,7 @@ public:
/// Uses cpp_function's return_value_policy by default
template <typename... Extra>
class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) {
return def_property(name, fget, nullptr, extra...);
return def_property(name, fget, cpp_function(), extra...);
}
/// Uses return_value_policy::reference by default
@@ -1237,7 +1210,7 @@ public:
/// Uses cpp_function's return_value_policy by default
template <typename... Extra>
class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) {
return def_property_static(name, fget, nullptr, extra...);
return def_property_static(name, fget, cpp_function(), extra...);
}
/// Uses return_value_policy::reference_internal by default
@@ -1266,28 +1239,22 @@ public:
/// Uses cpp_function's return_value_policy by default
template <typename... Extra>
class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) {
static_assert( 0 == detail::constexpr_sum(std::is_base_of<arg, Extra>::value...),
"Argument annotations are not allowed for properties");
auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset);
auto *rec_active = rec_fget;
if (rec_fget) {
char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */
detail::process_attributes<Extra...>::init(extra..., rec_fget);
if (rec_fget->doc && rec_fget->doc != doc_prev) {
free(doc_prev);
rec_fget->doc = strdup(rec_fget->doc);
}
char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */
detail::process_attributes<Extra...>::init(extra..., rec_fget);
if (rec_fget->doc && rec_fget->doc != doc_prev) {
free(doc_prev);
rec_fget->doc = strdup(rec_fget->doc);
}
if (rec_fset) {
char *doc_prev = rec_fset->doc;
doc_prev = rec_fset->doc;
detail::process_attributes<Extra...>::init(extra..., rec_fset);
if (rec_fset->doc && rec_fset->doc != doc_prev) {
free(doc_prev);
rec_fset->doc = strdup(rec_fset->doc);
}
if (! rec_active) rec_active = rec_fset;
}
def_property_static_impl(name, fget, fset, rec_active);
def_property_static_impl(name, fget, fset, rec_fget);
return *this;
}
@@ -1353,10 +1320,7 @@ private:
v_h.set_holder_constructed(false);
}
else {
detail::call_operator_delete(v_h.value_ptr<type>(),
v_h.type->type_size,
v_h.type->type_align
);
detail::call_operator_delete(v_h.value_ptr<type>(), v_h.type->type_size);
}
v_h.value_ptr() = nullptr;
}
@@ -1392,190 +1356,93 @@ detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetSta
return {std::forward<GetState>(g), std::forward<SetState>(s)};
}
NAMESPACE_BEGIN(detail)
struct enum_base {
enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { }
PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) {
m_base.attr("__entries") = dict();
auto property = handle((PyObject *) &PyProperty_Type);
auto static_property = handle((PyObject *) get_internals().static_property_type);
m_base.attr("__repr__") = cpp_function(
[](handle arg) -> str {
handle type = arg.get_type();
object type_name = type.attr("__name__");
dict entries = type.attr("__entries");
for (const auto &kv : entries) {
object other = kv.second[int_(0)];
if (other.equal(arg))
return pybind11::str("{}.{}").format(type_name, kv.first);
}
return pybind11::str("{}.???").format(type_name);
}, is_method(m_base)
);
m_base.attr("name") = property(cpp_function(
[](handle arg) -> str {
dict entries = arg.get_type().attr("__entries");
for (const auto &kv : entries) {
if (handle(kv.second[int_(0)]).equal(arg))
return pybind11::str(kv.first);
}
return "???";
}, is_method(m_base)
));
m_base.attr("__doc__") = static_property(cpp_function(
[](handle arg) -> std::string {
std::string docstring;
dict entries = arg.attr("__entries");
if (((PyTypeObject *) arg.ptr())->tp_doc)
docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n";
docstring += "Members:";
for (const auto &kv : entries) {
auto key = std::string(pybind11::str(kv.first));
auto comment = kv.second[int_(1)];
docstring += "\n\n " + key;
if (!comment.is_none())
docstring += " : " + (std::string) pybind11::str(comment);
}
return docstring;
}
), none(), none(), "");
m_base.attr("__members__") = static_property(cpp_function(
[](handle arg) -> dict {
dict entries = arg.attr("__entries"), m;
for (const auto &kv : entries)
m[kv.first] = kv.second[int_(0)];
return m;
}), none(), none(), ""
);
#define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \
m_base.attr(op) = cpp_function( \
[](object a, object b) { \
if (!a.get_type().is(b.get_type())) \
strict_behavior; \
return expr; \
}, \
is_method(m_base))
#define PYBIND11_ENUM_OP_CONV(op, expr) \
m_base.attr(op) = cpp_function( \
[](object a_, object b_) { \
int_ a(a_), b(b_); \
return expr; \
}, \
is_method(m_base))
if (is_convertible) {
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
if (is_arithmetic) {
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
PYBIND11_ENUM_OP_CONV("__and__", a & b);
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
PYBIND11_ENUM_OP_CONV("__or__", a | b);
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
}
} else {
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
if (is_arithmetic) {
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
#undef PYBIND11_THROW
}
}
#undef PYBIND11_ENUM_OP_CONV
#undef PYBIND11_ENUM_OP_STRICT
object getstate = cpp_function(
[](object arg) { return int_(arg); }, is_method(m_base));
m_base.attr("__getstate__") = getstate;
m_base.attr("__hash__") = getstate;
}
PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) {
dict entries = m_base.attr("__entries");
str name(name_);
if (entries.contains(name)) {
std::string type_name = (std::string) str(m_base.attr("__name__"));
throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!");
}
entries[name] = std::make_pair(value, doc);
m_base.attr(name) = value;
}
PYBIND11_NOINLINE void export_values() {
dict entries = m_base.attr("__entries");
for (const auto &kv : entries)
m_parent.attr(kv.first) = kv.second[int_(0)];
}
handle m_base;
handle m_parent;
};
NAMESPACE_END(detail)
/// Binds C++ enumerations and enumeration classes to Python
template <typename Type> class enum_ : public class_<Type> {
public:
using Base = class_<Type>;
using Base::def;
using Base::attr;
using Base::def_property_readonly;
using Base::def_property_readonly_static;
using class_<Type>::def;
using class_<Type>::def_property_readonly_static;
using Scalar = typename std::underlying_type<Type>::type;
template <typename... Extra>
enum_(const handle &scope, const char *name, const Extra&... extra)
: class_<Type>(scope, name, extra...), m_base(*this, scope) {
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
constexpr bool is_convertible = std::is_convertible<Type, Scalar>::value;
m_base.init(is_arithmetic, is_convertible);
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
auto m_entries_ptr = m_entries.inc_ref().ptr();
def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
if (pybind11::cast<Type>(kv.second) == value)
return pybind11::str("{}.{}").format(name, kv.first);
}
return pybind11::str("{}.???").format(name);
});
def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) {
dict m;
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
m[kv.first] = kv.second;
return m;
}, return_value_policy::copy);
def(init([](Scalar i) { return static_cast<Type>(i); }));
def("__int__", [](Type value) { return (Scalar) value; });
#if PY_MAJOR_VERSION < 3
def("__long__", [](Type value) { return (Scalar) value; });
#endif
cpp_function setstate(
[](Type &value, Scalar arg) { value = static_cast<Type>(arg); },
is_method(*this));
attr("__setstate__") = setstate;
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
if (is_arithmetic) {
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
}
if (std::is_convertible<Type, Scalar>::value) {
// Don't provide comparison with the underlying type if the enum isn't convertible,
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
// convert Type to Scalar below anyway because this needs to compile).
def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
if (is_arithmetic) {
def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
def("__invert__", [](const Type &value) { return ~((Scalar) value); });
def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
}
}
def("__hash__", [](const Type &value) { return (Scalar) value; });
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); },
[](tuple t) { return static_cast<Type>(t[0].cast<Scalar>()); }));
}
/// Export enumeration entries into the parent scope
enum_& export_values() {
m_base.export_values();
for (const auto &kv : m_entries)
m_parent.attr(kv.first) = kv.second;
return *this;
}
/// Add an enumeration entry
enum_& value(char const* name, Type value, const char *doc = nullptr) {
m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc);
enum_& value(char const* name, Type value) {
auto v = pybind11::cast(value, return_value_policy::copy);
this->attr(name) = v;
m_entries[pybind11::str(name)] = v;
return *this;
}
private:
detail::enum_base m_base;
dict m_entries;
handle m_parent;
};
NAMESPACE_BEGIN(detail)
@@ -1882,15 +1749,6 @@ public:
auto const &internals = detail::get_internals();
tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
if (!tstate) {
/* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if
calling from a Python thread). Since we use a different key, this ensures
we don't create a new thread state and deadlock in PyEval_AcquireThread
below. Note we don't save this state with internals.tstate, since we don't
create it we would fail to clear it (its reference count should be > 0). */
tstate = PyGILState_GetThisThreadState();
}
if (!tstate) {
tstate = PyThreadState_New(internals.istate);
#if !defined(NDEBUG)
@@ -1998,12 +1856,12 @@ class gil_scoped_release { };
#endif
error_already_set::~error_already_set() {
if (m_type) {
if (type) {
error_scope scope;
gil_scoped_acquire gil;
m_type.release().dec_ref();
m_value.release().dec_ref();
m_trace.release().dec_ref();
type.release().dec_ref();
value.release().dec_ref();
trace.release().dec_ref();
}
}
@@ -2064,14 +1922,6 @@ inline function get_type_overload(const void *this_ptr, const detail::type_info
return overload;
}
/** \rst
Try to retrieve a python method by the provided name from the instance pointed to by the this_ptr.
:this_ptr: The pointer to the object the overload should be retrieved for. This should be the first
non-trampoline class encountered in the inheritance chain.
:name: The name of the overloaded Python method to retrieve.
:return: The Python method by this name from the object or an empty function wrapper.
\endrst */
template <class T> function get_overload(const T *this_ptr, const char *name) {
auto tinfo = detail::get_type_info(typeid(T));
return tinfo ? get_type_overload(this_ptr, tinfo, name) : function();
@@ -2090,73 +1940,26 @@ template <class T> function get_overload(const T *this_ptr, const char *name) {
} \
}
/** \rst
Macro to populate the virtual method in the trampoline class. This macro tries to look up a method named 'fn'
from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return
the appropriate type. See :ref:`overriding_virtuals` for more information. This macro should be used when the method
name in C is not the same as the method name in Python. For example with `__str__`.
.. code-block:: cpp
std::string toString() override {
PYBIND11_OVERLOAD_NAME(
std::string, // Return type (ret_type)
Animal, // Parent class (cname)
toString, // Name of function in C++ (name)
"__str__", // Name of method in Python (fn)
);
}
\endrst */
#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
return cname::fn(__VA_ARGS__)
/** \rst
Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD_NAME`, except that it
throws if no overload can be found.
\endrst */
#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \
PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \
pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY(cname) "::" name "\"");
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\"");
/** \rst
Macro to populate the virtual method in the trampoline class. This macro tries to look up the method
from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return
the appropriate type. This macro should be used if the method name in C and in Python are identical.
See :ref:`overriding_virtuals` for more information.
.. code-block:: cpp
class PyAnimal : public Animal {
public:
// Inherit the constructors
using Animal::Animal;
// Trampoline (need one for each virtual function)
std::string go(int n_times) override {
PYBIND11_OVERLOAD_PURE(
std::string, // Return type (ret_type)
Animal, // Parent class (cname)
go, // Name of function in C++ (must match Python name) (fn)
n_times // Argument(s) (...)
);
}
};
\endrst */
#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \
PYBIND11_OVERLOAD_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
PYBIND11_OVERLOAD_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
/** \rst
Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD`, except that it throws
if no overload can be found.
\endrst */
#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \
PYBIND11_OVERLOAD_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, #fn, fn, __VA_ARGS__)
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
#if defined(_MSC_VER)
# pragma warning(pop)
#elif defined(__INTEL_COMPILER)
/* Leave ignored warnings on */
#elif defined(__GNUG__) && !defined(__clang__)
# pragma GCC diagnostic pop
#endif

View File

@@ -114,35 +114,6 @@ public:
bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); }
/// Equivalent to ``obj is None`` in Python.
bool is_none() const { return derived().ptr() == Py_None; }
/// Equivalent to obj == other in Python
bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); }
bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); }
bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); }
bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); }
bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); }
bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); }
object operator-() const;
object operator~() const;
object operator+(object_api const &other) const;
object operator+=(object_api const &other) const;
object operator-(object_api const &other) const;
object operator-=(object_api const &other) const;
object operator*(object_api const &other) const;
object operator*=(object_api const &other) const;
object operator/(object_api const &other) const;
object operator/=(object_api const &other) const;
object operator|(object_api const &other) const;
object operator|=(object_api const &other) const;
object operator&(object_api const &other) const;
object operator&=(object_api const &other) const;
object operator^(object_api const &other) const;
object operator^=(object_api const &other) const;
object operator<<(object_api const &other) const;
object operator<<=(object_api const &other) const;
object operator>>(object_api const &other) const;
object operator>>=(object_api const &other) const;
PYBIND11_DEPRECATED("Use py::str(obj) instead")
pybind11::str str() const;
@@ -153,9 +124,6 @@ public:
int ref_count() const { return static_cast<int>(Py_REFCNT(derived().ptr())); }
/// Return a handle to the Python type object underlying the instance
handle get_type() const;
private:
bool rich_compare(object_api const &other, int value) const;
};
NAMESPACE_END(detail)
@@ -324,18 +292,15 @@ public:
/// Constructs a new exception from the current Python error indicator, if any. The current
/// Python error indicator will be cleared.
error_already_set() : std::runtime_error(detail::error_string()) {
PyErr_Fetch(&m_type.ptr(), &m_value.ptr(), &m_trace.ptr());
PyErr_Fetch(&type.ptr(), &value.ptr(), &trace.ptr());
}
error_already_set(const error_already_set &) = default;
error_already_set(error_already_set &&) = default;
inline ~error_already_set();
/// Give the currently-held error back to Python, if any. If there is currently a Python error
/// already set it is cleared first. After this call, the current object no longer stores the
/// error variables (but the `.what()` string is still available).
void restore() { PyErr_Restore(m_type.release().ptr(), m_value.release().ptr(), m_trace.release().ptr()); }
void restore() { PyErr_Restore(type.release().ptr(), value.release().ptr(), trace.release().ptr()); }
// Does nothing; provided for backwards compatibility.
PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated")
@@ -344,14 +309,10 @@ public:
/// Check if the currently trapped error type matches the given Python exception class (or a
/// subclass thereof). May also be passed a tuple to search for any exception class matches in
/// the given tuple.
bool matches(handle exc) const { return PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()); }
const object& type() const { return m_type; }
const object& value() const { return m_value; }
const object& trace() const { return m_trace; }
bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), type.ptr()); }
private:
object m_type, m_value, m_trace;
object type, value, trace;
};
/** \defgroup python_builtins _
@@ -392,14 +353,6 @@ inline bool hasattr(handle obj, const char *name) {
return PyObject_HasAttrString(obj.ptr(), name) == 1;
}
inline void delattr(handle obj, handle name) {
if (PyObject_DelAttr(obj.ptr(), name.ptr()) != 0) { throw error_already_set(); }
}
inline void delattr(handle obj, const char *name) {
if (PyObject_DelAttrString(obj.ptr(), name) != 0) { throw error_already_set(); }
}
inline object getattr(handle obj, handle name) {
PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr());
if (!result) { throw error_already_set(); }
@@ -471,6 +424,7 @@ object object_or_cast(T &&o);
// Match a PyObject*, which we want to convert directly to handle via its converting constructor
inline handle object_or_cast(PyObject *ptr) { return ptr; }
template <typename Policy>
class accessor : public object_api<accessor<Policy>> {
using key_type = typename Policy::key_type;
@@ -708,7 +662,7 @@ protected:
private:
handle obj;
PyObject *key = nullptr, *value = nullptr;
PyObject *key, *value;
ssize_t pos = -1;
};
NAMESPACE_END(iterator_policies)
@@ -736,14 +690,9 @@ inline bool PyIterable_Check(PyObject *obj) {
}
inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
#if PY_MAJOR_VERSION >= 3
inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; }
#endif
inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); }
inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; }
class kwargs_proxy : public handle {
public:
explicit kwargs_proxy(handle h) : handle(h) { }
@@ -1015,14 +964,6 @@ public:
none() : object(Py_None, borrowed_t{}) { }
};
#if PY_MAJOR_VERSION >= 3
class ellipsis : public object {
public:
PYBIND11_OBJECT(ellipsis, object, detail::PyEllipsis_Check)
ellipsis() : object(Py_Ellipsis, borrowed_t{}) { }
};
#endif
class bool_ : public object {
public:
PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool)
@@ -1133,13 +1074,6 @@ public:
(ssize_t *) stop, (ssize_t *) step,
(ssize_t *) slicelength) == 0;
}
bool compute(ssize_t length, ssize_t *start, ssize_t *stop, ssize_t *step,
ssize_t *slicelength) const {
return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr,
length, start,
stop, step,
slicelength) == 0;
}
};
class capsule : public object {
@@ -1203,7 +1137,6 @@ public:
}
size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
detail::tuple_iterator begin() const { return {*this, 0}; }
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
};
@@ -1241,7 +1174,6 @@ public:
PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check)
size_t size() const { return (size_t) PySequence_Size(m_ptr); }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
detail::sequence_iterator begin() const { return {*this, 0}; }
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
};
@@ -1254,7 +1186,6 @@ public:
}
size_t size() const { return (size_t) PyList_Size(m_ptr); }
detail::list_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
detail::list_iterator begin() const { return {*this, 0}; }
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
template <typename T> void append(T &&val) const {
@@ -1290,11 +1221,6 @@ public:
bool is_cpp_function() const { return (bool) cpp_function(); }
};
class staticmethod : public object {
public:
PYBIND11_OBJECT_CVT(staticmethod, object, detail::PyStaticMethod_Check, PyStaticMethod_New)
};
class buffer : public object {
public:
PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer)
@@ -1353,21 +1279,6 @@ inline size_t len(handle h) {
return (size_t) result;
}
inline size_t len_hint(handle h) {
#if PY_VERSION_HEX >= 0x03040000
ssize_t result = PyObject_LengthHint(h.ptr(), 0);
#else
ssize_t result = PyObject_Length(h.ptr());
#endif
if (result < 0) {
// Sometimes a length can't be determined at all (eg generators)
// In which case simply return 0
PyErr_Clear();
return 0;
}
return (size_t) result;
}
inline str repr(handle h) {
PyObject *str_value = PyObject_Repr(h.ptr());
if (!str_value) throw error_already_set();
@@ -1417,55 +1328,5 @@ str_attr_accessor object_api<D>::doc() const { return attr("__doc__"); }
template <typename D>
handle object_api<D>::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); }
template <typename D>
bool object_api<D>::rich_compare(object_api const &other, int value) const {
int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value);
if (rv == -1)
throw error_already_set();
return rv == 1;
}
#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \
template <typename D> object object_api<D>::op() const { \
object result = reinterpret_steal<object>(fn(derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}
#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \
template <typename D> \
object object_api<D>::op(object_api const &other) const { \
object result = reinterpret_steal<object>( \
fn(derived().ptr(), other.derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}
PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert)
PYBIND11_MATH_OPERATOR_UNARY (operator-, PyNumber_Negative)
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
#undef PYBIND11_MATH_OPERATOR_UNARY
#undef PYBIND11_MATH_OPERATOR_BINARY
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -16,7 +16,6 @@
#include <unordered_map>
#include <iostream>
#include <list>
#include <deque>
#include <valarray>
#if defined(_MSC_VER)
@@ -84,8 +83,7 @@ template <typename Type, typename Key> struct set_caster {
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
@@ -95,7 +93,7 @@ template <typename Type, typename Key> struct set_caster {
return s.release();
}
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]"));
};
template <typename Type, typename Key, typename Value> struct map_caster {
@@ -121,12 +119,8 @@ template <typename Type, typename Key, typename Value> struct map_caster {
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
return_value_policy policy_key = return_value_policy_override<Key>::policy(policy);
return_value_policy policy_value = return_value_policy_override<Value>::policy(policy);
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
@@ -137,14 +131,14 @@ template <typename Type, typename Key, typename Value> struct map_caster {
return d.release();
}
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]"));
};
template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
if (!isinstance<sequence>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
@@ -167,8 +161,7 @@ private:
public:
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
@@ -180,15 +173,12 @@ public:
return l.release();
}
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]"));
};
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> { };
@@ -209,9 +199,9 @@ private:
public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
if (!isinstance<list>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
auto l = reinterpret_borrow<list>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
@@ -237,7 +227,7 @@ public:
return l.release();
}
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
};
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
@@ -284,7 +274,7 @@ template<typename T> struct optional_caster {
return true;
}
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]"));
};
#if PYBIND11_HAS_OPTIONAL
@@ -364,7 +354,7 @@ struct variant_caster<V<Ts...>> {
}
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
};
#if PYBIND11_HAS_VARIANT

View File

@@ -122,7 +122,7 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
cl.def(init([](iterable it) {
auto v = std::unique_ptr<Vector>(new Vector());
v->reserve(len_hint(it));
v->reserve(len(it));
for (handle h : it)
v->push_back(h.cast<T>());
return v.release();
@@ -136,28 +136,6 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
"Extend the list by appending all the items in the given list"
);
cl.def("extend",
[](Vector &v, iterable it) {
const size_t old_size = v.size();
v.reserve(old_size + len_hint(it));
try {
for (handle h : it) {
v.push_back(h.cast<T>());
}
} catch (const cast_error &) {
v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
try {
v.shrink_to_fit();
} catch (const std::exception &) {
// Do nothing
}
throw;
}
},
arg("L"),
"Extend the list by appending all the items in the given list"
);
cl.def("insert",
[](Vector &v, SizeType i, const T &x) {
if (i > v.size())
@@ -601,15 +579,6 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
return_value_policy::reference_internal // ref + keepalive
);
cl.def("__contains__",
[](Map &m, const KeyType &k) -> bool {
auto it = m.find(k);
if (it == m.end())
return false;
return true;
}
);
// Assignment provided only if the type is copyable
detail::map_assignment<Map, Class_>(cl);

View File

@@ -4,7 +4,6 @@
#include <string>
#include <regex>
#include <algorithm>
#include "tensorflow/core/framework/tensor.h"
#include "triton/codegen/selection/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/code_gen.h"
@@ -22,15 +21,15 @@ using namespace triton;
namespace rt = triton::runtime;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
/* TF triton op properties */
std::map<size_t, tf_grid_fn_ty> id_grid_map;
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
std::map<size_t, rt::function*> id_fn_map;
std::map<size_t, int64_t> i64scalar_map;
void register_grid(size_t id,
const tf_grid_fn_ty& grid_fn) {
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id] = grid_fn;
}
@@ -43,6 +42,17 @@ size_t register_fn(const std::string& src,
return id;
}
size_t make_scalar_id() {
return i64scalar_map.size();
}
bool has_scalar(size_t id) {
return i64scalar_map.find(id) != i64scalar_map.end();
}
int64_t retrieve_scalar(size_t id) {
return i64scalar_map.at(id);
}
/* TF source-code generation */
@@ -112,13 +122,6 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl;
os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl;
os << " rt::grid_t result;" << std::endl;
os << " for(auto& x: tmp) { result.push_back(x.scalar<int>()()); }" << std::endl;
os << " return result; }; " << std::endl;
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
@@ -129,7 +132,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << ", ";
os << name;
}
os << "}, grid_fn, stream); \n";
os << "}, id_grid_map.at(id_), stream); \n";
}
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -239,9 +242,7 @@ using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
extern std::map<size_t, tf_grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function*> id_fn_map;
@@ -307,7 +308,8 @@ PYBIND11_MODULE(libtriton, m) {
// bindings for triton classes
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
.def("D", &options_t::D<int>);
.def("d", &options_t::D<int>)
.def_readonly("num_warps", &options_t::num_warps);
pybind11::class_<options_space_t>(m, "options_space")
.def(pybind11::init<>())
@@ -317,4 +319,6 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("register_grid", &register_grid);
m.def("register_fn", &register_fn);
m.def("make_scalar_id", &make_scalar_id);
m.def("retrieve_scalar", &retrieve_scalar);
}

View File

@@ -0,0 +1,37 @@
#include <map>
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
extern std::map<size_t, int64_t> i64scalar_map;
class RegisterScalarOp : public OpKernel {
public:
explicit RegisterScalarOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
}
void Compute(OpKernelContext* context) override {
// fetch input
const Tensor& x = context->input(0);
const int32* x_data = (const int32*)x.tensor_data().data();
const int32 x_rank = x.dims();
OP_REQUIRES(context, x_rank == 0, errors::InvalidArgument("Input must be a scalar"));
i64scalar_map[id_] = *x_data;
context->set_output(0, x);
}
private:
int id_;
};
REGISTER_KERNEL_BUILDER(Name("RegisterScalar")
.HostMemory("x")
.Device(DEVICE_CPU), RegisterScalarOp);
REGISTER_OP("RegisterScalar")
.Input("x: int32")
.Output("y: int32")
.Attr("id: int")
;

View File

@@ -104,9 +104,77 @@ def _cvt_to_def_str(obj):
tf.float64: 'double'}[obj]
return str(obj)
class op:
def _make_tensorflow_op(self, src, outputs, options):
class scalar:
def __init__(self, x):
self.id = libtriton.make_scalar_id()
self.handle = extra_ops.register_scalar(x, id=self.id)
self.assume_initialized = False
def set_assume_initialized(self):
self.assume_initialized = True
def unset_assume_initialized(self):
self.assume_initialized = False
def get_value(self):
if self.assume_initialized:
return libtriton.retrieve_scalar(self.id)
else:
return self.handle
def __add__(self, other):
return self.get_value() + other
def __radd__(self, other):
return other + self.get_value()
def __sub__(self, other):
return self.get_value() - other
def __rsub(self, other):
return other - self.get_value()
def __mul__(self, other):
return self.get_value() * other
def __rmul(self, other):
return other * self.get_value()
def __floordiv__(self, other):
return self.get_value() // other
def __rfloordiv__(self, other):
return other // self.get_value()
def __div__(self, other):
return self.get_value() / other
def __rdiv__(self, other):
return other / self.get_value()
def __truediv__(self, other):
self.get_value().__truediv__(other)
def __rtruediv__(self, other):
other.__truediv__(self.get_value())
def __neg__(self):
return -self.get_value()
class lazy_shape:
def __init__(self, shape):
self.shape = shape
def __getitem__(self, key):
return scalar(self.shape[key])
def shape(A) :
return lazy_shape(tf.shape(A))
def _make_tensorflow_op(src, outputs, options):
src, name = make_bindings(src, outputs, options)
cache_path = make_cache_path(src)
cpp, so = write_bindings(src, cache_path)
@@ -114,15 +182,18 @@ class op:
result = tf.load_op_library(so)
return result.__dict__[name]
class op:
def __init__(self, src, outputs):
self.fw_ops = dict()
self.src = src
self.outputs = outputs
pass
def D(self, name):
pass
def __del__(self):
libtriton.unregister_grid(self.id)
def __call__(self, *args, **kwargs):
# recompilation key
key = zip(kwargs.keys(), kwargs.values())
@@ -141,11 +212,23 @@ class op:
opt.num_warps = [1, 2, 4, 8]
# register framework op
id = libtriton.register_fn(self.src, opt)
self.fw_ops[key] = (self._make_tensorflow_op(self.src, self.outputs, opt), id)
self.fw_ops[key] = (_make_tensorflow_op(self.src, self.outputs, opt), id)
# retrieve framework op
op, id = self.fw_ops[key]
libtriton.register_grid(id, args[-1])
op_args = args[:-1]
op, id = self.fw_ops[key]
# create grid function
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
result = args[-1](opt)
for x in scalars:
x.unset_assume_initialized()
return result
# register grid function
self.grid = grid
libtriton.register_grid(id, self.grid)
# create operands
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
return op(*op_args, id=id)
@@ -158,4 +241,6 @@ def make_tensorflow_op(src, outputs, grids):
return result.__dict__[name]
def empty(shapes):
return extra_ops.alloc_empty(tf.stack(shapes))
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = tf.stack(args)
return extra_ops.alloc_empty(args)