File indexing completed on 2025-01-18 10:06:15
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #pragma once
0011
0012 #include "pybind11.h"
0013 #include "detail/common.h"
0014 #include "complex.h"
0015 #include "gil_safe_call_once.h"
0016 #include "pytypes.h"
0017
0018 #include <algorithm>
0019 #include <array>
0020 #include <cstdint>
0021 #include <cstdlib>
0022 #include <cstring>
0023 #include <functional>
0024 #include <numeric>
0025 #include <sstream>
0026 #include <string>
0027 #include <type_traits>
0028 #include <typeindex>
0029 #include <utility>
0030 #include <vector>
0031
// PYBIND11_NUMPY_1_ONLY has to be visible to every pybind11 header, not just this one.
// NOTE(review): PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED is presumably defined by an
// earlier-included pybind11 header when it sees the option — confirm in detail/common.h.
#if defined(PYBIND11_NUMPY_1_ONLY) && !defined(PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED)
#    error PYBIND11_NUMPY_1_ONLY must be defined before any pybind11 header is included.
#endif

// This header reinterprets NumPy's Py_intptr_t shape/stride buffers as ssize_t (and back,
// e.g. when calling PyArray_NewFromDescr), which is only valid if the two types have the
// same size and signedness.
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
0043
0044
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)

// Disable MSVC warning C4127 ("conditional expression is constant"); presumably triggered
// by compile-time-constant conditions such as `Dims >= 0 && ...` in this header.
PYBIND11_WARNING_DISABLE_MSVC(4127)

// Forward declarations so the detail helpers below can name the wrappers before they are
// defined.
class dtype;
class array;

PYBIND11_NAMESPACE_BEGIN(detail)
0053
// Python-visible type names attached to the dtype / array wrappers; presumably used by
// pybind11 when rendering function signatures and docstrings.
template <>
struct handle_type_name<dtype> {
    static constexpr auto name = const_name("numpy.dtype");
};

template <>
struct handle_type_name<array> {
    static constexpr auto name = const_name("numpy.ndarray");
};

// Maps a C++ element type to its NumPy dtype; specializations are selected via SFINAE.
template <typename type, typename SFINAE = void>
struct npy_format_descriptor;
0066
0067
// Leading fields of NumPy 1.x's PyArray_Descr. Descriptor objects are reinterpret_cast to
// this struct (see array_descriptor1_proxy), so field order and types must match NumPy's
// layout exactly; only the prefix this header reads is declared.
struct PyArrayDescr1_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char flags; // NumPy 1.x stores descriptor flags in a single char
    int type_num;
    int elsize;    // item size in bytes (int-sized in the 1.x layout)
    int alignment;
    char *subarray; // declared as char* only to keep the layout; never dereferenced here
    PyObject *fields;
    PyObject *names; // non-null for structured dtypes (see dtype::has_fields)
};
0082
#ifndef PYBIND11_NUMPY_1_ONLY
// Version-independent prefix shared by the NumPy 1.x and 2.x descriptor layouts. Fields
// after type_num differ between versions, so they are read through the versioned proxies.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char _former_flags; // the slot where NumPy 1.x kept `char flags`
    int type_num;
    /* Additional fields are NumPy version specific. */
};
#else
// When only NumPy 1.x is supported, the full 1.x layout is the common layout.
using PyArrayDescr_Proxy = PyArrayDescr1_Proxy;
#endif
0098
0099
// Leading fields of the NumPy 2.x PyArray_Descr layout, used when the runtime feature
// version is >= 0x12 (see dtype::itemsize). Must match NumPy's layout exactly.
struct PyArrayDescr2_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char _former_flags;
    int type_num;
    std::uint64_t flags; // widened from the 1.x single char
    ssize_t elsize;      // widened from int
    ssize_t alignment;   // widened from int
    PyObject *metadata;
    Py_hash_t hash;
    void *reserved_null[2];
    // NOTE(review): the following fields appear to be valid only for legacy descriptors;
    // dtype::has_fields guards reads with a type_num range check before touching `names`.
    char *subarray;
    PyObject *fields;
    PyObject *names;
};
0119
// Leading fields of NumPy's ndarray object (PyArrayObject); ndarray pointers are
// reinterpret_cast to this struct (see array_proxy). Field order must match NumPy's layout.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;          // address of the first element
    int nd;              // number of dimensions
    ssize_t *dimensions; // nd extents
    ssize_t *strides;    // nd byte strides
    PyObject *base;
    PyObject *descr;     // the array's dtype (PyArray_Descr)
    int flags;           // NPY_ARRAY_* bits, tested via check_flags
};
0130
// Layout mirror of NumPy's void-scalar object; presumably used to read structured scalars
// (NOTE(review): consumers are outside this chunk — confirm). Field order must match NumPy.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};
0137
// Registry record for a C++ type registered with NumPy (stored in numpy_internals).
struct numpy_type_info {
    PyObject *dtype_ptr;    // presumably the registered dtype object — confirm ownership
    std::string format_str; // presumably the buffer-protocol format string for the type
};
0142
0143 struct numpy_internals {
0144 std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
0145
0146 numpy_type_info *get_type_info(const std::type_info &tinfo, bool throw_if_missing = true) {
0147 auto it = registered_dtypes.find(std::type_index(tinfo));
0148 if (it != registered_dtypes.end()) {
0149 return &(it->second);
0150 }
0151 if (throw_if_missing) {
0152 pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
0153 }
0154 return nullptr;
0155 }
0156
0157 template <typename T>
0158 numpy_type_info *get_type_info(bool throw_if_missing = true) {
0159 return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
0160 }
0161 };
0162
0163 PYBIND11_NOINLINE void load_numpy_internals(numpy_internals *&ptr) {
0164 ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
0165 }
0166
0167 inline numpy_internals &get_numpy_internals() {
0168 static numpy_internals *ptr = nullptr;
0169 if (!ptr) {
0170 load_numpy_internals(ptr);
0171 }
0172 return *ptr;
0173 }
0174
0175 PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
0176 module_ numpy = module_::import("numpy");
0177 str version_string = numpy.attr("__version__");
0178
0179 module_ numpy_lib = module_::import("numpy.lib");
0180 object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
0181 int major_version = numpy_version.attr("major").cast<int>();
0182
0183 #ifdef PYBIND11_NUMPY_1_ONLY
0184 if (major_version >= 2) {
0185 throw std::runtime_error(
0186 "This extension was built with PYBIND11_NUMPY_1_ONLY defined, "
0187 "but NumPy 2 is used in this process. For NumPy2 compatibility, "
0188 "this extension needs to be rebuilt without the PYBIND11_NUMPY_1_ONLY define.");
0189 }
0190 #endif
0191
0192
0193 std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
0194 return module_::import((numpy_core_path + "." + submodule_name).c_str());
0195 }
0196
0197 template <typename T>
0198 struct same_size {
0199 template <typename U>
0200 using as = bool_constant<sizeof(T) == sizeof(U)>;
0201 };
0202
/// Recursion end: none of the candidate types matched, reported as -1.
template <typename Concrete>
constexpr int platform_lookup() {
    return -1;
}

/// Walk the candidate types `T, Ts...` in lockstep with the type numbers passed as
/// arguments, returning the number paired with the first candidate whose size equals
/// sizeof(Concrete), or -1 when no candidate has that size.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int first_typenum, Ints... remaining_typenums) {
    return sizeof(T) != sizeof(Concrete)
               ? platform_lookup<Concrete, Ts...>(remaining_typenums...)
               : first_typenum;
}
0213
// Constants and function pointers for the subset of the NumPy C API used by this header.
// The function table is resolved at runtime from NumPy's `_ARRAY_API` capsule (lookup()),
// so NumPy's own headers are not needed at build time.
struct npy_api {
    // Local copies of NumPy constants: NPY_ARRAY_* flag bits followed by NPY_TYPES numbers.
    enum constants {
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
        NPY_ARRAY_OWNDATA_ = 0x0004,
        NPY_ARRAY_FORCECAST_ = 0x0010,
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
        NPY_BOOL_ = 0,
        NPY_BYTE_,
        NPY_UBYTE_,
        NPY_SHORT_,
        NPY_USHORT_,
        NPY_INT_,
        NPY_UINT_,
        NPY_LONG_,
        NPY_ULONG_,
        NPY_LONGLONG_,
        NPY_ULONGLONG_,
        NPY_FLOAT_,
        NPY_DOUBLE_,
        NPY_LONGDOUBLE_,
        NPY_CFLOAT_,
        NPY_CDOUBLE_,
        NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_,
        NPY_UNICODE_,
        NPY_VOID_,
        // Fixed-width aliases. 8/16-bit types map directly; 32/64-bit types are resolved by
        // size against the platform's long/int/short/long long via platform_lookup().
        NPY_INT8_ = NPY_BYTE_,
        NPY_UINT8_ = NPY_UBYTE_,
        NPY_INT16_ = NPY_SHORT_,
        NPY_UINT16_ = NPY_USHORT_,
        NPY_INT32_
        = platform_lookup<std::int32_t, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_),
        NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
            NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
        NPY_INT64_
        = platform_lookup<std::int64_t, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
        NPY_UINT64_
        = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
            NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
    };

    // Feature version of the NumPy loaded at runtime. lookup() requires >= 0x7; dtype
    // accessors compare it with 0x12 to choose between the 1.x and 2.x descriptor layouts.
    unsigned int PyArray_RUNTIME_VERSION_;

    // Mirror of NumPy's PyArray_Dims argument (pointer to extents + count).
    struct PyArray_Dims {
        Py_intptr_t *ptr;
        int len;
    };

    // Single shared instance, filled exactly once by lookup() through
    // gil_safe_call_once_and_store.
    static npy_api &get() {
        PYBIND11_CONSTINIT static gil_safe_call_once_and_store<npy_api> storage;
        return storage.call_once_and_store_result(lookup).get_stored();
    }

    // Type checks against the ndarray / dtype type objects (subclasses included).
    bool PyArray_Check_(PyObject *obj) const {
        return PyObject_TypeCheck(obj, PyArray_Type_) != 0;
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0;
    }

    // Function pointers and type objects copied out of the _ARRAY_API table. Signatures
    // must match NumPy's; each slot index is listed in `functions` below.
    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
                                       PyObject *,
                                       int,
                                       Py_intptr_t const *,
                                       Py_intptr_t const *,
                                       void *,
                                       int,
                                       PyObject *);
    // Unused, but taking up space in the struct, so ABI-sensitive to remove.
    PyObject *(*PyArray_DescrNewFromType_)(int);
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
    PyTypeObject *PyVoidArrType_Type_;
    PyTypeObject *PyArrayDescr_Type_;
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
    PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
    // NOTE(review): NumPy declares this as returning npy_bool (unsigned char); declaring
    // `bool` relies on the two sharing a return convention on supported platforms.
    bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
#ifdef PYBIND11_NUMPY_1_ONLY
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
                                             PyObject *,
                                             unsigned char,
                                             PyObject **,
                                             int *,
                                             Py_intptr_t *,
                                             PyObject **,
                                             PyObject *);
#endif
    PyObject *(*PyArray_Squeeze_)(PyObject *);
    // Unused, but taking up space in the struct, so ABI-sensitive to remove.
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
    PyObject *(*PyArray_Resize_)(PyObject *, PyArray_Dims *, int, int);
    PyObject *(*PyArray_Newshape_)(PyObject *, PyArray_Dims *, int);
    PyObject *(*PyArray_View_)(PyObject *, PyObject *, PyObject *);

private:
    // Slot indices into NumPy's _ARRAY_API pointer table.
    enum functions {
        API_PyArray_GetNDArrayCFeatureVersion = 211,
        API_PyArray_Type = 2,
        API_PyArrayDescr_Type = 3,
        API_PyVoidArrType_Type = 39,
        API_PyArray_DescrFromType = 45,
        API_PyArray_DescrFromScalar = 57,
        API_PyArray_FromAny = 69,
        API_PyArray_Resize = 80,
        // NOTE(review): slot 50 is used for CopyInto — presumably because it is the slot
        // that remains valid across NumPy 1.x and 2.x; confirm against NumPy's API table.
        API_PyArray_CopyInto = 50,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 96,
        API_PyArray_Newshape = 135,
        API_PyArray_Squeeze = 136,
        API_PyArray_View = 137,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
#ifdef PYBIND11_NUMPY_1_ONLY
        API_PyArray_GetArrayParamsFromObject = 278,
#endif
        API_PyArray_SetBaseObject = 282
    };

    // Import <numpy core>.multiarray, unwrap its `_ARRAY_API` capsule and copy the needed
    // entries into a fresh npy_api. Raises SystemError (via error_already_set) if the capsule
    // pointer cannot be obtained, and fails if NumPy's feature version is below 0x7.
    static npy_api lookup() {
        module_ m = detail::import_numpy_core_submodule("multiarray");
        auto c = m.attr("_ARRAY_API");
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
        if (api_ptr == nullptr) {
            raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
            throw error_already_set();
        }
        npy_api api;
        // Copy slot API_<Func> into member <Func>_, cast to that member's pointer type.
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
        if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
        }
        DECL_NPY_API(PyArray_Type);
        DECL_NPY_API(PyVoidArrType_Type);
        DECL_NPY_API(PyArrayDescr_Type);
        DECL_NPY_API(PyArray_DescrFromType);
        DECL_NPY_API(PyArray_DescrFromScalar);
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_Resize);
        DECL_NPY_API(PyArray_CopyInto);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_Newshape);
        DECL_NPY_API(PyArray_Squeeze);
        DECL_NPY_API(PyArray_View);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
#ifdef PYBIND11_NUMPY_1_ONLY
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#endif
        DECL_NPY_API(PyArray_SetBaseObject);

#undef DECL_NPY_API
        return api;
    }
};
0386
0387 inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
0388
0389 inline const PyArray_Proxy *array_proxy(const void *ptr) {
0390 return reinterpret_cast<const PyArray_Proxy *>(ptr);
0391 }
0392
0393 inline PyArrayDescr_Proxy *array_descriptor_proxy(PyObject *ptr) {
0394 return reinterpret_cast<PyArrayDescr_Proxy *>(ptr);
0395 }
0396
0397 inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
0398 return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
0399 }
0400
0401 inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
0402 return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
0403 }
0404
0405 inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
0406 return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
0407 }
0408
0409 inline bool check_flags(const void *ptr, int flag) {
0410 return (flag == (array_proxy(ptr)->flags & flag));
0411 }
0412
/// Detects std::array<T, N> for any element type and extent.
template <typename T>
struct is_std_array : std::integral_constant<bool, false> {};
template <typename T, size_t N>
struct is_std_array<std::array<T, N>> : std::integral_constant<bool, true> {};

/// Detects std::complex<T> for any component type.
template <typename T>
struct is_complex : std::integral_constant<bool, false> {};
template <typename T>
struct is_complex<std::complex<T>> : std::integral_constant<bool, true> {};
0421
// Traits for a non-array element type T: rank 0, no extents to report.
template <typename T>
struct array_info_scalar {
    using type = T;
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
    static constexpr auto extents = const_name("");
    static void append_extents(list & /* shape */) {}
};

// Compile-time description of (possibly nested) std::array / C array types:
//   type            - innermost element type
//   is_empty        - true if any extent is 0
//   extents         - descriptor text naming each dimension, outermost first
//   append_extents  - pushes each extent onto a Python list, outermost first
template <typename T>
struct array_info : array_info_scalar<T> {};
template <typename T, size_t N>
struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // Append this dimension, then recurse into the element type.
    static void append_extents(list &shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

    // If the element is itself an array, prefix its extents with N; otherwise just N.
    static constexpr auto extents = const_name<array_info<T>::is_array>(
        ::pybind11::detail::concat(const_name<N>(), array_info<T>::extents), const_name<N>());
};

// char arrays are treated as scalars rather than arrays of char (presumably so they map to
// fixed-length byte-string dtypes — confirm against npy_format_descriptor).
template <size_t N>
struct array_info<char[N]> : array_info_scalar<char[N]> {};
template <size_t N>
struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> {};
// Plain C arrays are described exactly like the equivalent std::array.
template <typename T, size_t N>
struct array_info<T[N]> : array_info<std::array<T, N>> {};
// Innermost element type of a (nested) array type; T itself for scalars.
template <typename T>
using remove_all_extents_t = typename array_info<T>::type;
0461
// True for standard-layout, trivially copyable class types that are not one of the kinds
// handled separately: references, C arrays, std::array, arithmetic types, std::complex,
// or enums.
template <typename T>
using is_pod_struct
    = all_of<std::is_standard_layout<T>, // since we're accessing directly in memory
                                         // we need a standard layout type
// These libstdc++ releases (identified by __GLIBCXX__ date) presumably lack
// std::is_trivially_copyable, so it is approximated with the older traits.
#if defined(__GLIBCXX__)                                                                          \
    && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623            \
        || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
             std::is_trivially_destructible<T>,
             satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#else
             std::is_trivially_copyable<T>,
#endif
             satisfies_none_of<T,
                               std::is_reference,
                               std::is_array,
                               is_std_array,
                               std::is_arithmetic,
                               is_complex,
                               std::is_enum>>;

// Standard-layout and trivial (trivially default-constructible and trivially copyable).
template <typename T>
using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;
0487
/// Recursion end: no indices left, so they contribute no offset.
template <ssize_t Dim = 0, typename Strides>
ssize_t byte_offset_unsafe(const Strides &) {
    return 0;
}

/// Byte offset of an element: the sum of index_k * strides[Dim + k] over the given indices.
/// No bounds or rank checking is performed.
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t first, Ix... rest) {
    const ssize_t leading = first * strides[Dim];
    return leading + byte_offset_unsafe<Dim + 1>(strides, rest...);
}
0496
0497
0498
0499
0500
0501
0502 template <typename T, ssize_t Dims>
0503 class unchecked_reference {
0504 protected:
0505 static constexpr bool Dynamic = Dims < 0;
0506 const unsigned char *data_;
0507
0508
0509 conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>> shape_, strides_;
0510 const ssize_t dims_;
0511
0512 friend class pybind11::array;
0513
0514 template <bool Dyn = Dynamic>
0515 unchecked_reference(const void *data,
0516 const ssize_t *shape,
0517 const ssize_t *strides,
0518 enable_if_t<!Dyn, ssize_t>)
0519 : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
0520 for (size_t i = 0; i < (size_t) dims_; i++) {
0521 shape_[i] = shape[i];
0522 strides_[i] = strides[i];
0523 }
0524 }
0525
0526 template <bool Dyn = Dynamic>
0527 unchecked_reference(const void *data,
0528 const ssize_t *shape,
0529 const ssize_t *strides,
0530 enable_if_t<Dyn, ssize_t> dims)
0531 : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides},
0532 dims_{dims} {}
0533
0534 public:
0535
0536
0537
0538
0539
0540 template <typename... Ix>
0541 const T &operator()(Ix... index) const {
0542 static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
0543 "Invalid number of indices for unchecked array reference");
0544 return *reinterpret_cast<const T *>(data_
0545 + byte_offset_unsafe(strides_, ssize_t(index)...));
0546 }
0547
0548
0549
0550
0551 template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
0552 const T &operator[](ssize_t index) const {
0553 return operator()(index);
0554 }
0555
0556
0557 template <typename... Ix>
0558 const T *data(Ix... ix) const {
0559 return &operator()(ssize_t(ix)...);
0560 }
0561
0562
0563 constexpr static ssize_t itemsize() { return sizeof(T); }
0564
0565
0566 ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
0567
0568
0569 ssize_t ndim() const { return dims_; }
0570
0571
0572
0573 template <bool Dyn = Dynamic>
0574 enable_if_t<!Dyn, ssize_t> size() const {
0575 return std::accumulate(
0576 shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
0577 }
0578 template <bool Dyn = Dynamic>
0579 enable_if_t<Dyn, ssize_t> size() const {
0580 return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
0581 }
0582
0583
0584
0585
0586 ssize_t nbytes() const { return size() * itemsize(); }
0587 };
0588
0589 template <typename T, ssize_t Dims>
0590 class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
0591 friend class pybind11::array;
0592 using ConstBase = unchecked_reference<T, Dims>;
0593 using ConstBase::ConstBase;
0594 using ConstBase::Dynamic;
0595
0596 public:
0597
0598 using ConstBase::operator();
0599 using ConstBase::operator[];
0600
0601
0602 template <typename... Ix>
0603 T &operator()(Ix... index) {
0604 static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
0605 "Invalid number of indices for unchecked array reference");
0606 return const_cast<T &>(ConstBase::operator()(index...));
0607 }
0608
0609
0610
0611
0612
0613 template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
0614 T &operator[](ssize_t index) {
0615 return operator()(index);
0616 }
0617
0618
0619 template <typename... Ix>
0620 T *mutable_data(Ix... ix) {
0621 return &operator()(ssize_t(ix)...);
0622 }
0623 };
0624
// Deliberately uncastable: any attempt to convert an unchecked proxy to/from Python hits
// this static_assert. The condition is always false but depends on `Dim`, so it only fires
// when the specialization is actually instantiated.
template <typename T, ssize_t Dim>
struct type_caster<unchecked_reference<T, Dim>> {
    static_assert(Dim == 0 && Dim > 0 /* always fail */,
                  "unchecked array proxy object is not castable");
};
// The mutable proxy inherits the same compile-time rejection.
template <typename T, ssize_t Dim>
struct type_caster<unchecked_mutable_reference<T, Dim>>
    : type_caster<unchecked_reference<T, Dim>> {};
0633
0634 PYBIND11_NAMESPACE_END(detail)
0635
0636 class dtype : public object {
0637 public:
0638 PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
0639
0640 explicit dtype(const buffer_info &info) {
0641 dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
0642
0643 m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize())
0644 .release()
0645 .ptr();
0646 }
0647
0648 explicit dtype(const pybind11::str &format) : dtype(from_args(format)) {}
0649
0650 explicit dtype(const std::string &format) : dtype(pybind11::str(format)) {}
0651
0652 explicit dtype(const char *format) : dtype(pybind11::str(format)) {}
0653
0654 dtype(list names, list formats, list offsets, ssize_t itemsize) {
0655 dict args;
0656 args["names"] = std::move(names);
0657 args["formats"] = std::move(formats);
0658 args["offsets"] = std::move(offsets);
0659 args["itemsize"] = pybind11::int_(itemsize);
0660 m_ptr = from_args(args).release().ptr();
0661 }
0662
0663
0664
0665 explicit dtype(int typenum)
0666 : object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
0667 if (m_ptr == nullptr) {
0668 throw error_already_set();
0669 }
0670 }
0671
0672
0673 static dtype from_args(const object &args) {
0674 PyObject *ptr = nullptr;
0675 if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr) {
0676 throw error_already_set();
0677 }
0678 return reinterpret_steal<dtype>(ptr);
0679 }
0680
0681
0682 template <typename T>
0683 static dtype of() {
0684 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
0685 }
0686
0687
0688 #ifdef PYBIND11_NUMPY_1_ONLY
0689 ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
0690 #else
0691 ssize_t itemsize() const {
0692 if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
0693 return detail::array_descriptor1_proxy(m_ptr)->elsize;
0694 }
0695 return detail::array_descriptor2_proxy(m_ptr)->elsize;
0696 }
0697 #endif
0698
0699
0700 #ifdef PYBIND11_NUMPY_1_ONLY
0701 bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
0702 #else
0703 bool has_fields() const {
0704 if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
0705 return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
0706 }
0707 const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
0708 if (proxy->type_num < 0 || proxy->type_num >= 2056) {
0709 return false;
0710 }
0711 return proxy->names != nullptr;
0712 }
0713 #endif
0714
0715
0716
0717 char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; }
0718
0719
0720
0721 char char_() const {
0722
0723
0724
0725 return detail::array_descriptor_proxy(m_ptr)->type;
0726 }
0727
0728
0729 int num() const {
0730
0731
0732
0733 return detail::array_descriptor_proxy(m_ptr)->type_num;
0734 }
0735
0736
0737 char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
0738
0739
0740 #ifdef PYBIND11_NUMPY_1_ONLY
0741 int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
0742 #else
0743 ssize_t alignment() const {
0744 if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
0745 return detail::array_descriptor1_proxy(m_ptr)->alignment;
0746 }
0747 return detail::array_descriptor2_proxy(m_ptr)->alignment;
0748 }
0749 #endif
0750
0751
0752 #ifdef PYBIND11_NUMPY_1_ONLY
0753 char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
0754 #else
0755 std::uint64_t flags() const {
0756 if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
0757 return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
0758 }
0759 return detail::array_descriptor2_proxy(m_ptr)->flags;
0760 }
0761 #endif
0762
0763 private:
0764 static object &_dtype_from_pep3118() {
0765 PYBIND11_CONSTINIT static gil_safe_call_once_and_store<object> storage;
0766 return storage
0767 .call_once_and_store_result([]() {
0768 return detail::import_numpy_core_submodule("_internal")
0769 .attr("_dtype_from_pep3118");
0770 })
0771 .get_stored();
0772 }
0773
0774 dtype strip_padding(ssize_t itemsize) {
0775
0776
0777 if (!has_fields()) {
0778 return *this;
0779 }
0780
0781 struct field_descr {
0782 pybind11::str name;
0783 object format;
0784 pybind11::int_ offset;
0785 field_descr(pybind11::str &&name, object &&format, pybind11::int_ &&offset)
0786 : name{std::move(name)}, format{std::move(format)}, offset{std::move(offset)} {};
0787 };
0788 auto field_dict = attr("fields").cast<dict>();
0789 std::vector<field_descr> field_descriptors;
0790 field_descriptors.reserve(field_dict.size());
0791
0792 for (auto field : field_dict.attr("items")()) {
0793 auto spec = field.cast<tuple>();
0794 auto name = spec[0].cast<pybind11::str>();
0795 auto spec_fo = spec[1].cast<tuple>();
0796 auto format = spec_fo[0].cast<dtype>();
0797 auto offset = spec_fo[1].cast<pybind11::int_>();
0798 if ((len(name) == 0u) && format.kind() == 'V') {
0799 continue;
0800 }
0801 field_descriptors.emplace_back(
0802 std::move(name), format.strip_padding(format.itemsize()), std::move(offset));
0803 }
0804
0805 std::sort(field_descriptors.begin(),
0806 field_descriptors.end(),
0807 [](const field_descr &a, const field_descr &b) {
0808 return a.offset.cast<int>() < b.offset.cast<int>();
0809 });
0810
0811 list names, formats, offsets;
0812 for (auto &descr : field_descriptors) {
0813 names.append(std::move(descr.name));
0814 formats.append(std::move(descr.format));
0815 offsets.append(std::move(descr.offset));
0816 }
0817 return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);
0818 }
0819 };
0820
0821 class array : public buffer {
0822 public:
0823 PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
0824
0825 enum {
0826 c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
0827 f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
0828 forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
0829 };
0830
0831 array() : array(0, static_cast<const double *>(nullptr)) {}
0832
0833 using ShapeContainer = detail::any_container<ssize_t>;
0834 using StridesContainer = detail::any_container<ssize_t>;
0835
0836
0837 array(const pybind11::dtype &dt,
0838 ShapeContainer shape,
0839 StridesContainer strides,
0840 const void *ptr = nullptr,
0841 handle base = handle()) {
0842
0843 if (strides->empty()) {
0844 *strides = detail::c_strides(*shape, dt.itemsize());
0845 }
0846
0847 auto ndim = shape->size();
0848 if (ndim != strides->size()) {
0849 pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
0850 }
0851 auto descr = dt;
0852
0853 int flags = 0;
0854 if (base && ptr) {
0855 if (isinstance<array>(base)) {
0856
0857 flags = reinterpret_borrow<array>(base).flags()
0858 & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
0859 } else {
0860
0861 flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
0862 }
0863 }
0864
0865 auto &api = detail::npy_api::get();
0866 auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
0867 api.PyArray_Type_,
0868 descr.release().ptr(),
0869 (int) ndim,
0870
0871 reinterpret_cast<Py_intptr_t *>(shape->data()),
0872 reinterpret_cast<Py_intptr_t *>(strides->data()),
0873 const_cast<void *>(ptr),
0874 flags,
0875 nullptr));
0876 if (!tmp) {
0877 throw error_already_set();
0878 }
0879 if (ptr) {
0880 if (base) {
0881 api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
0882 } else {
0883 tmp = reinterpret_steal<object>(
0884 api.PyArray_NewCopy_(tmp.ptr(), -1 ));
0885 }
0886 }
0887 m_ptr = tmp.release().ptr();
0888 }
0889
0890 array(const pybind11::dtype &dt,
0891 ShapeContainer shape,
0892 const void *ptr = nullptr,
0893 handle base = handle())
0894 : array(dt, std::move(shape), {}, ptr, base) {}
0895
0896 template <typename T,
0897 typename
0898 = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
0899 array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
0900 : array(dt, {{count}}, ptr, base) {}
0901
0902 template <typename T>
0903 array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
0904 : array(pybind11::dtype::of<T>(),
0905 std::move(shape),
0906 std::move(strides),
0907 reinterpret_cast<const void *>(ptr),
0908 base) {}
0909
0910 template <typename T>
0911 array(ShapeContainer shape, const T *ptr, handle base = handle())
0912 : array(std::move(shape), {}, ptr, base) {}
0913
0914 template <typename T>
0915 explicit array(ssize_t count, const T *ptr, handle base = handle())
0916 : array({count}, {}, ptr, base) {}
0917
0918 explicit array(const buffer_info &info, handle base = handle())
0919 : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) {}
0920
0921
0922 pybind11::dtype dtype() const {
0923 return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
0924 }
0925
0926
0927 ssize_t size() const {
0928 return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
0929 }
0930
0931
0932 ssize_t itemsize() const { return dtype().itemsize(); }
0933
0934
0935 ssize_t nbytes() const { return size() * itemsize(); }
0936
0937
0938 ssize_t ndim() const { return detail::array_proxy(m_ptr)->nd; }
0939
0940
0941 object base() const { return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base); }
0942
0943
0944 const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
0945
0946
0947 ssize_t shape(ssize_t dim) const {
0948 if (dim >= ndim()) {
0949 fail_dim_check(dim, "invalid axis");
0950 }
0951 return shape()[dim];
0952 }
0953
0954
0955 const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
0956
0957
0958 ssize_t strides(ssize_t dim) const {
0959 if (dim >= ndim()) {
0960 fail_dim_check(dim, "invalid axis");
0961 }
0962 return strides()[dim];
0963 }
0964
0965
0966 int flags() const { return detail::array_proxy(m_ptr)->flags; }
0967
0968
0969 bool writeable() const {
0970 return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
0971 }
0972
0973
0974 bool owndata() const {
0975 return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
0976 }
0977
0978
0979
0980 template <typename... Ix>
0981 const void *data(Ix... index) const {
0982 return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
0983 }
0984
0985
0986
0987
0988 template <typename... Ix>
0989 void *mutable_data(Ix... index) {
0990 check_writeable();
0991 return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
0992 }
0993
0994
0995
0996 template <typename... Ix>
0997 ssize_t offset_at(Ix... index) const {
0998 if ((ssize_t) sizeof...(index) > ndim()) {
0999 fail_dim_check(sizeof...(index), "too many indices for an array");
1000 }
1001 return byte_offset(ssize_t(index)...);
1002 }
1003
1004 ssize_t offset_at() const { return 0; }
1005
1006
1007
1008 template <typename... Ix>
1009 ssize_t index_at(Ix... index) const {
1010 return offset_at(index...) / itemsize();
1011 }
1012
1013
1014
1015
1016
1017
1018
1019 template <typename T, ssize_t Dims = -1>
1020 detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
1021 if (Dims >= 0 && ndim() != Dims) {
1022 throw std::domain_error("array has incorrect number of dimensions: "
1023 + std::to_string(ndim()) + "; expected "
1024 + std::to_string(Dims));
1025 }
1026 return detail::unchecked_mutable_reference<T, Dims>(
1027 mutable_data(), shape(), strides(), ndim());
1028 }
1029
1030
1031
1032
1033
1034
1035
1036
1037 template <typename T, ssize_t Dims = -1>
1038 detail::unchecked_reference<T, Dims> unchecked() const & {
1039 if (Dims >= 0 && ndim() != Dims) {
1040 throw std::domain_error("array has incorrect number of dimensions: "
1041 + std::to_string(ndim()) + "; expected "
1042 + std::to_string(Dims));
1043 }
1044 return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
1045 }
1046
1047
1048 array squeeze() {
1049 auto &api = detail::npy_api::get();
1050 return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
1051 }
1052
1053
1054
1055
1056 void resize(ShapeContainer new_shape, bool refcheck = true) {
1057 detail::npy_api::PyArray_Dims d
1058 = {
1059 reinterpret_cast<Py_intptr_t *>(new_shape->data()),
1060 int(new_shape->size())};
1061
1062 auto new_array = reinterpret_steal<object>(
1063 detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1));
1064 if (!new_array) {
1065 throw error_already_set();
1066 }
1067 if (isinstance<array>(new_array)) {
1068 *this = std::move(new_array);
1069 }
1070 }
1071
1072
1073 array reshape(ShapeContainer new_shape) {
1074 detail::npy_api::PyArray_Dims d
1075 = {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
1076 auto new_array
1077 = reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
1078 if (!new_array) {
1079 throw error_already_set();
1080 }
1081 return new_array;
1082 }
1083
1084
1085
1086
1087
1088
1089 array view(const std::string &dtype) {
1090 auto &api = detail::npy_api::get();
1091 auto new_view = reinterpret_steal<array>(api.PyArray_View_(
1092 m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
1093 if (!new_view) {
1094 throw error_already_set();
1095 }
1096 return new_view;
1097 }
1098
1099
1100
1101 static array ensure(handle h, int ExtraFlags = 0) {
1102 auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
1103 if (!result) {
1104 PyErr_Clear();
1105 }
1106 return result;
1107 }
1108
protected:
    // Every detail::npy_format_descriptor specialization may access array's protected members.
    template <typename, typename>
    friend struct detail::npy_format_descriptor;
1112
1113 void fail_dim_check(ssize_t dim, const std::string &msg) const {
1114 throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim())
1115 + ')');
1116 }
1117
1118 template <typename... Ix>
1119 ssize_t byte_offset(Ix... index) const {
1120 check_dimensions(index...);
1121 return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
1122 }
1123
1124 void check_writeable() const {
1125 if (!writeable()) {
1126 throw std::domain_error("array is not writeable");
1127 }
1128 }
1129
1130 template <typename... Ix>
1131 void check_dimensions(Ix... index) const {
1132 check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
1133 }
1134
1135 void check_dimensions_impl(ssize_t, const ssize_t *) const {}
1136
1137 template <typename... Ix>
1138 void check_dimensions_impl(ssize_t axis, const ssize_t *shape, ssize_t i, Ix... index) const {
1139 if (i >= *shape) {
1140 throw index_error(std::string("index ") + std::to_string(i)
1141 + " is out of bounds for axis " + std::to_string(axis)
1142 + " with size " + std::to_string(*shape));
1143 }
1144 check_dimensions_impl(axis + 1, shape + 1, index...);
1145 }
1146
1147
1148 static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
1149 if (ptr == nullptr) {
1150 set_error(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
1151 return nullptr;
1152 }
1153 return detail::npy_api::get().PyArray_FromAny_(
1154 ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
1155 }
1156 };
1157
1158 template <typename T, int ExtraFlags = array::forcecast>
1159 class array_t : public array {
1160 private:
1161 struct private_ctor {};
1162
1163 array_t(private_ctor,
1164 ShapeContainer &&shape,
1165 StridesContainer &&strides,
1166 const T *ptr,
1167 handle base)
1168 : array(std::move(shape), std::move(strides), ptr, base) {}
1169
1170 public:
1171 static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
1172
1173 using value_type = T;
1174
1175 array_t() : array(0, static_cast<const T *>(nullptr)) {}
1176 array_t(handle h, borrowed_t) : array(h, borrowed_t{}) {}
1177 array_t(handle h, stolen_t) : array(h, stolen_t{}) {}
1178
1179 PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
1180 array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
1181 if (!m_ptr) {
1182 PyErr_Clear();
1183 }
1184 if (!is_borrowed) {
1185 Py_XDECREF(h.ptr());
1186 }
1187 }
1188
1189
1190 array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
1191 if (!m_ptr) {
1192 throw error_already_set();
1193 }
1194 }
1195
1196 explicit array_t(const buffer_info &info, handle base = handle()) : array(info, base) {}
1197
1198 array_t(ShapeContainer shape,
1199 StridesContainer strides,
1200 const T *ptr = nullptr,
1201 handle base = handle())
1202 : array(std::move(shape), std::move(strides), ptr, base) {}
1203
1204 explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
1205 : array_t(private_ctor{},
1206 std::move(shape),
1207 (ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize())
1208 : detail::c_strides(*shape, itemsize()),
1209 ptr,
1210 base) {}
1211
1212 explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
1213 : array({count}, {}, ptr, base) {}
1214
1215 constexpr ssize_t itemsize() const { return sizeof(T); }
1216
1217 template <typename... Ix>
1218 ssize_t index_at(Ix... index) const {
1219 return offset_at(index...) / itemsize();
1220 }
1221
1222 template <typename... Ix>
1223 const T *data(Ix... index) const {
1224 return static_cast<const T *>(array::data(index...));
1225 }
1226
1227 template <typename... Ix>
1228 T *mutable_data(Ix... index) {
1229 return static_cast<T *>(array::mutable_data(index...));
1230 }
1231
1232
1233 template <typename... Ix>
1234 const T &at(Ix... index) const {
1235 if ((ssize_t) sizeof...(index) != ndim()) {
1236 fail_dim_check(sizeof...(index), "index dimension mismatch");
1237 }
1238 return *(static_cast<const T *>(array::data())
1239 + byte_offset(ssize_t(index)...) / itemsize());
1240 }
1241
1242
1243 template <typename... Ix>
1244 T &mutable_at(Ix... index) {
1245 if ((ssize_t) sizeof...(index) != ndim()) {
1246 fail_dim_check(sizeof...(index), "index dimension mismatch");
1247 }
1248 return *(static_cast<T *>(array::mutable_data())
1249 + byte_offset(ssize_t(index)...) / itemsize());
1250 }
1251
1252
1253
1254
1255
1256
1257
1258 template <ssize_t Dims = -1>
1259 detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
1260 return array::mutable_unchecked<T, Dims>();
1261 }
1262
1263
1264
1265
1266
1267
1268
1269
1270 template <ssize_t Dims = -1>
1271 detail::unchecked_reference<T, Dims> unchecked() const & {
1272 return array::unchecked<T, Dims>();
1273 }
1274
1275
1276
1277 static array_t ensure(handle h) {
1278 auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
1279 if (!result) {
1280 PyErr_Clear();
1281 }
1282 return result;
1283 }
1284
1285 static bool check_(handle h) {
1286 const auto &api = detail::npy_api::get();
1287 return api.PyArray_Check_(h.ptr())
1288 && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr,
1289 dtype::of<T>().ptr())
1290 && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
1291 }
1292
1293 protected:
1294
1295 static PyObject *raw_array_t(PyObject *ptr) {
1296 if (ptr == nullptr) {
1297 set_error(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
1298 return nullptr;
1299 }
1300 return detail::npy_api::get().PyArray_FromAny_(ptr,
1301 dtype::of<T>().release().ptr(),
1302 0,
1303 0,
1304 detail::npy_api::NPY_ARRAY_ENSUREARRAY_
1305 | ExtraFlags,
1306 nullptr);
1307 }
1308 };
1309
1310 template <typename T>
1311 struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
1312 static std::string format() {
1313 return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
1314 }
1315 };
1316
1317 template <size_t N>
1318 struct format_descriptor<char[N]> {
1319 static std::string format() { return std::to_string(N) + 's'; }
1320 };
1321 template <size_t N>
1322 struct format_descriptor<std::array<char, N>> {
1323 static std::string format() { return std::to_string(N) + 's'; }
1324 };
1325
1326 template <typename T>
1327 struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
1328 static std::string format() {
1329 return format_descriptor<
1330 typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
1331 }
1332 };
1333
1334 template <typename T>
1335 struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
1336 static std::string format() {
1337 using namespace detail;
1338 static constexpr auto extents = const_name("(") + array_info<T>::extents + const_name(")");
1339 return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
1340 }
1341 };
1342
1343 PYBIND11_NAMESPACE_BEGIN(detail)
1344 template <typename T, int ExtraFlags>
1345 struct pyobject_caster<array_t<T, ExtraFlags>> {
1346 using type = array_t<T, ExtraFlags>;
1347
1348 bool load(handle src, bool convert) {
1349 if (!convert && !type::check_(src)) {
1350 return false;
1351 }
1352 value = type::ensure(src);
1353 return static_cast<bool>(value);
1354 }
1355
1356 static handle cast(const handle &src, return_value_policy , handle ) {
1357 return src.inc_ref();
1358 }
1359 PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
1360 };
1361
1362 template <typename T>
1363 struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
1364 static bool compare(const buffer_info &b) {
1365 return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
1366 }
1367 };
1368
// Python-facing type names for numeric element types. Specialized below per type category.
template <typename T, typename = void>
struct npy_format_descriptor_name;

// Integral types: "bool" for bool. Otherwise "numpy.int" or "numpy.uint" (by signedness)
// followed by the bit width sizeof(T) * 8.
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = const_name<std::is_same<T, bool>::value>(
        const_name("bool"),
        const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
            + const_name<sizeof(T) * 8>());
};
1379
// Floating-point types: float and double (optionally const-qualified) map to "numpy.float"
// followed by their bit width. Every other floating-point type (e.g. long double) maps to
// "numpy.longdouble".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name = const_name < std::is_same<T, float>::value
                                 || std::is_same<T, const float>::value
                                 || std::is_same<T, double>::value
                                 || std::is_same<T, const double>::value
                                        > (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
                                           const_name("numpy.longdouble"));
};
1389
// Complex types: a float or double value_type (optionally const) maps to "numpy.complex"
// followed by the total bit width, sizeof(value_type) * 16 (two components). Every other
// value_type maps to "numpy.longcomplex".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
                                 || std::is_same<typename T::value_type, const float>::value
                                 || std::is_same<typename T::value_type, double>::value
                                 || std::is_same<typename T::value_type, const double>::value
                                        > (const_name("numpy.complex")
                                               + const_name<sizeof(typename T::value_type) * 16>(),
                                           const_name("numpy.longcomplex"));
};
1400
// dtype descriptor for arithmetic and complex element types. The Python-facing name is
// inherited from npy_format_descriptor_name<T>.
template <typename T>
struct npy_format_descriptor<
    T,
    enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
    : npy_format_descriptor_name<T> {
private:
    // NumPy type numbers, indexed by detail::is_fmt_numeric<T>::index (see `value` below).
    // The order of this table must therefore match the type ordering used by is_fmt_numeric.
    constexpr static const int values[15] = {npy_api::NPY_BOOL_,
                                             npy_api::NPY_BYTE_,
                                             npy_api::NPY_UBYTE_,
                                             npy_api::NPY_INT16_,
                                             npy_api::NPY_UINT16_,
                                             npy_api::NPY_INT32_,
                                             npy_api::NPY_UINT32_,
                                             npy_api::NPY_INT64_,
                                             npy_api::NPY_UINT64_,
                                             npy_api::NPY_FLOAT_,
                                             npy_api::NPY_DOUBLE_,
                                             npy_api::NPY_LONGDOUBLE_,
                                             npy_api::NPY_CFLOAT_,
                                             npy_api::NPY_CDOUBLE_,
                                             npy_api::NPY_CLONGDOUBLE_};

public:
    // NumPy type number for T.
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    // dtype constructed from that type number.
    static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
};
1429
1430 template <typename T>
1431 struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
1432 static constexpr auto name = const_name("object");
1433
1434 static constexpr int value = npy_api::NPY_OBJECT_;
1435
1436 static pybind11::dtype dtype() { return pybind11::dtype( value); }
1437 };
1438
// Shared members for the fixed-length char-buffer specializations below. The macro expects a
// size_t `N` in scope and provides the name "S" followed by N and a dtype built from the
// string "S<N>". It is #undef'd right after use.
#define PYBIND11_DECL_CHAR_FMT \
    static constexpr auto name = const_name("S") + const_name<N>(); \
    static pybind11::dtype dtype() { \
        return pybind11::dtype(std::string("S") + std::to_string(N)); \
    }
template <size_t N>
struct npy_format_descriptor<char[N]> {
    PYBIND11_DECL_CHAR_FMT
};
template <size_t N>
struct npy_format_descriptor<std::array<char, N>> {
    PYBIND11_DECL_CHAR_FMT
};
#undef PYBIND11_DECL_CHAR_FMT
1453
1454 template <typename T>
1455 struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
1456 private:
1457 using base_descr = npy_format_descriptor<typename array_info<T>::type>;
1458
1459 public:
1460 static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
1461
1462 static constexpr auto name
1463 = const_name("(") + array_info<T>::extents + const_name(")") + base_descr::name;
1464 static pybind11::dtype dtype() {
1465 list shape;
1466 array_info<T>::append_extents(shape);
1467 return pybind11::dtype::from_args(
1468 pybind11::make_tuple(base_descr::dtype(), std::move(shape)));
1469 }
1470 };
1471
1472 template <typename T>
1473 struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
1474 private:
1475 using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
1476
1477 public:
1478 static constexpr auto name = base_descr::name;
1479 static pybind11::dtype dtype() { return base_descr::dtype(); }
1480 };
1481
// One data member of a structured dtype, as produced by PYBIND11_FIELD_DESCRIPTOR_EX.
// That macro aggregate-initializes these members positionally, so their order must not change.
struct field_descriptor {
    const char *name;   // field name exposed to NumPy
    ssize_t offset;     // byte offset of the member inside the struct (offsetof)
    ssize_t size;       // byte size of the member (sizeof)
    std::string format; // buffer-protocol format string of the member
    dtype descr;        // NumPy dtype of the member; a null dtype is rejected at registration
};
1489
1490 PYBIND11_NOINLINE void register_structured_dtype(any_container<field_descriptor> fields,
1491 const std::type_info &tinfo,
1492 ssize_t itemsize,
1493 bool (*direct_converter)(PyObject *, void *&)) {
1494
1495 auto &numpy_internals = get_numpy_internals();
1496 if (numpy_internals.get_type_info(tinfo, false)) {
1497 pybind11_fail("NumPy: dtype is already registered");
1498 }
1499
1500
1501
1502 std::vector<field_descriptor> ordered_fields(std::move(fields));
1503 std::sort(
1504 ordered_fields.begin(),
1505 ordered_fields.end(),
1506 [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
1507
1508 list names, formats, offsets;
1509 for (auto &field : ordered_fields) {
1510 if (!field.descr) {
1511 pybind11_fail(std::string("NumPy: unsupported field dtype: `") + field.name + "` @ "
1512 + tinfo.name());
1513 }
1514 names.append(pybind11::str(field.name));
1515 formats.append(field.descr);
1516 offsets.append(pybind11::int_(field.offset));
1517 }
1518 auto *dtype_ptr
1519 = pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize)
1520 .release()
1521 .ptr();
1522
1523
1524
1525
1526
1527
1528
1529
1530 ssize_t offset = 0;
1531 std::ostringstream oss;
1532
1533
1534
1535
1536
1537 oss << "^T{";
1538 for (auto &field : ordered_fields) {
1539 if (field.offset > offset) {
1540 oss << (field.offset - offset) << 'x';
1541 }
1542 oss << field.format << ':' << field.name << ':';
1543 offset = field.offset + field.size;
1544 }
1545 if (itemsize > offset) {
1546 oss << (itemsize - offset) << 'x';
1547 }
1548 oss << '}';
1549 auto format_str = oss.str();
1550
1551
1552 auto &api = npy_api::get();
1553 auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
1554 if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) {
1555 pybind11_fail("NumPy: invalid buffer descriptor!");
1556 }
1557
1558 auto tindex = std::type_index(tinfo);
1559 numpy_internals.registered_dtypes[tindex] = {dtype_ptr, std::move(format_str)};
1560 with_internals([tindex, &direct_converter](internals &internals) {
1561 internals.direct_conversions[tindex].push_back(direct_converter);
1562 });
1563 }
1564
1565 template <typename T, typename SFINAE>
1566 struct npy_format_descriptor {
1567 static_assert(is_pod_struct<T>::value,
1568 "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
1569
1570 static constexpr auto name = make_caster<T>::name;
1571
1572 static pybind11::dtype dtype() { return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); }
1573
1574 static std::string format() {
1575 static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1576 return format_str;
1577 }
1578
1579 static void register_dtype(any_container<field_descriptor> fields) {
1580 register_structured_dtype(std::move(fields),
1581 typeid(typename std::remove_cv<T>::type),
1582 sizeof(T),
1583 &direct_converter);
1584 }
1585
1586 private:
1587 static PyObject *dtype_ptr() {
1588 static PyObject *ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
1589 return ptr;
1590 }
1591
1592 static bool direct_converter(PyObject *obj, void *&value) {
1593 auto &api = npy_api::get();
1594 if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) {
1595 return false;
1596 }
1597 if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1598 if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1599 value = ((PyVoidScalarObject_Proxy *) obj)->obval;
1600 return true;
1601 }
1602 }
1603 return false;
1604 }
1605 };
1606
#ifdef __CLION_IDE__ // CLion's code model gets no-op stand-ins for the registration macros
#    define PYBIND11_NUMPY_DTYPE(Type, ...) ((void) 0)
#    define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void) 0)
#else

// Expands to a braced field_descriptor for data member `Field` of struct `T`, registered under
// the NumPy field name `Name`. The initializers follow field_descriptor's member order:
// name, offsetof, sizeof the member, its buffer format string, and its NumPy dtype.
#    define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
        ::pybind11::detail::field_descriptor { \
            Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
                ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
                ::pybind11::detail::npy_format_descriptor< \
                    decltype(std::declval<T>().Field)>::dtype() \
        }

// Same as PYBIND11_FIELD_DESCRIPTOR_EX, using the stringized member name as the field name.
#    define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

// Preprocessor "map" machinery: PYBIND11_MAP_LIST(f, t, a, b, ...) expands to the
// comma-separated list f(t, a), f(t, b), ...
//
// PYBIND11_EVAL nests PYBIND11_EVAL0 3^5 = 243 levels deep. This forces repeated rescans so the
// mutually recursive PYBIND11_MAP_LIST0 / PYBIND11_MAP_LIST1 pair keeps expanding. That depth
// presumably bounds how many fields a single registration call can list.
#    define PYBIND11_EVAL0(...) __VA_ARGS__
#    define PYBIND11_EVAL1(...) PYBIND11_EVAL0(PYBIND11_EVAL0(PYBIND11_EVAL0(__VA_ARGS__)))
#    define PYBIND11_EVAL2(...) PYBIND11_EVAL1(PYBIND11_EVAL1(PYBIND11_EVAL1(__VA_ARGS__)))
#    define PYBIND11_EVAL3(...) PYBIND11_EVAL2(PYBIND11_EVAL2(PYBIND11_EVAL2(__VA_ARGS__)))
#    define PYBIND11_EVAL4(...) PYBIND11_EVAL3(PYBIND11_EVAL3(PYBIND11_EVAL3(__VA_ARGS__)))
#    define PYBIND11_EVAL(...) PYBIND11_EVAL4(PYBIND11_EVAL4(PYBIND11_EVAL4(__VA_ARGS__)))
// End detection: the argument list is terminated with a `()` sentinel. When `test` is `()`,
// `PYBIND11_MAP_GET_END test` expands to `0, PYBIND11_MAP_END`. That shifts the arguments of
// PYBIND11_MAP_NEXT0, so `next` becomes PYBIND11_MAP_END, which swallows the following
// parenthesized arguments and stops the recursion.
#    define PYBIND11_MAP_END(...)
#    define PYBIND11_MAP_OUT
#    define PYBIND11_MAP_COMMA ,
#    define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#    define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#    define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0(test, next, 0)
#    define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1(PYBIND11_MAP_GET_END test, next)
// MSVC (without clang) gets an extra PYBIND11_EVAL0 wrapper, presumably to force an expansion
// its preprocessor would otherwise defer.
#    if defined(_MSC_VER) \
        && !defined(__clang__)
#        define PYBIND11_MAP_LIST_NEXT1(test, next) \
            PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
#    else
#        define PYBIND11_MAP_LIST_NEXT1(test, next) \
            PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
#    endif
#    define PYBIND11_MAP_LIST_NEXT(test, next) \
        PYBIND11_MAP_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
#    define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
        f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST1)(f, t, peek, __VA_ARGS__)
#    define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
        f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST0)(f, t, peek, __VA_ARGS__)

// Applies f(t, x) to every x in the variadic list; the `(), 0` tail is the end sentinel.
#    define PYBIND11_MAP_LIST(f, t, ...) \
        PYBIND11_EVAL(PYBIND11_MAP_LIST1(f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE(Type, field1, field2, ...): registers Type's structured dtype, exposing
// each listed data member under its own name.
#    define PYBIND11_NUMPY_DTYPE(Type, ...) \
        ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
            ::std::vector<::pybind11::detail::field_descriptor>{ \
                PYBIND11_MAP_LIST(PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

// Two-argument variant of the map machinery: consumes the list in (x1, x2) pairs and applies
// f(t, x1, x2) to each pair.
#    if defined(_MSC_VER) && !defined(__clang__)
#        define PYBIND11_MAP2_LIST_NEXT1(test, next) \
            PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
#    else
#        define PYBIND11_MAP2_LIST_NEXT1(test, next) \
            PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
#    endif
#    define PYBIND11_MAP2_LIST_NEXT(test, next) \
        PYBIND11_MAP2_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
#    define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
        f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST1)(f, t, peek, __VA_ARGS__)
#    define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
        f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST0)(f, t, peek, __VA_ARGS__)

// Applies f(t, x1, x2) to every (x1, x2) pair; the `(), 0` tail is the end sentinel.
#    define PYBIND11_MAP2_LIST(f, t, ...) \
        PYBIND11_EVAL(PYBIND11_MAP2_LIST1(f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE_EX(Type, field1, "name1", field2, "name2", ...): like
// PYBIND11_NUMPY_DTYPE, with an explicit NumPy field name for every member.
#    define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
        ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
            ::std::vector<::pybind11::detail::field_descriptor>{ \
                PYBIND11_MAP2_LIST(PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

#endif
1684
1685 class common_iterator {
1686 public:
1687 using container_type = std::vector<ssize_t>;
1688 using value_type = container_type::value_type;
1689 using size_type = container_type::size_type;
1690
1691 common_iterator() : m_strides() {}
1692
1693 common_iterator(void *ptr, const container_type &strides, const container_type &shape)
1694 : p_ptr(reinterpret_cast<char *>(ptr)), m_strides(strides.size()) {
1695 m_strides.back() = static_cast<value_type>(strides.back());
1696 for (size_type i = m_strides.size() - 1; i != 0; --i) {
1697 size_type j = i - 1;
1698 auto s = static_cast<value_type>(shape[i]);
1699 m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
1700 }
1701 }
1702
1703 void increment(size_type dim) { p_ptr += m_strides[dim]; }
1704
1705 void *data() const { return p_ptr; }
1706
1707 private:
1708 char *p_ptr{nullptr};
1709 container_type m_strides;
1710 };
1711
1712 template <size_t N>
1713 class multi_array_iterator {
1714 public:
1715 using container_type = std::vector<ssize_t>;
1716
1717 multi_array_iterator(const std::array<buffer_info, N> &buffers, const container_type &shape)
1718 : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() {
1719
1720
1721 for (size_t i = 0; i < shape.size(); ++i) {
1722 m_shape[i] = shape[i];
1723 }
1724
1725 container_type strides(shape.size());
1726 for (size_t i = 0; i < N; ++i) {
1727 init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
1728 }
1729 }
1730
1731 multi_array_iterator &operator++() {
1732 for (size_t j = m_index.size(); j != 0; --j) {
1733 size_t i = j - 1;
1734 if (++m_index[i] != m_shape[i]) {
1735 increment_common_iterator(i);
1736 break;
1737 }
1738 m_index[i] = 0;
1739 }
1740 return *this;
1741 }
1742
1743 template <size_t K, class T = void>
1744 T *data() const {
1745 return reinterpret_cast<T *>(m_common_iterator[K].data());
1746 }
1747
1748 private:
1749 using common_iter = common_iterator;
1750
1751 void init_common_iterator(const buffer_info &buffer,
1752 const container_type &shape,
1753 common_iter &iterator,
1754 container_type &strides) {
1755 auto buffer_shape_iter = buffer.shape.rbegin();
1756 auto buffer_strides_iter = buffer.strides.rbegin();
1757 auto shape_iter = shape.rbegin();
1758 auto strides_iter = strides.rbegin();
1759
1760 while (buffer_shape_iter != buffer.shape.rend()) {
1761 if (*shape_iter == *buffer_shape_iter) {
1762 *strides_iter = *buffer_strides_iter;
1763 } else {
1764 *strides_iter = 0;
1765 }
1766
1767 ++buffer_shape_iter;
1768 ++buffer_strides_iter;
1769 ++shape_iter;
1770 ++strides_iter;
1771 }
1772
1773 std::fill(strides_iter, strides.rend(), 0);
1774 iterator = common_iter(buffer.ptr, strides, shape);
1775 }
1776
1777 void increment_common_iterator(size_t dim) {
1778 for (auto &iter : m_common_iterator) {
1779 iter.increment(dim);
1780 }
1781 }
1782
1783 container_type m_shape;
1784 container_type m_index;
1785 std::array<common_iter, N> m_common_iterator;
1786 };
1787
// Result of broadcast(). It says whether every non-singleton vectorized input can be walked as
// one flat buffer in C order (c_trivial) or in Fortran order (f_trivial), or whether
// per-element broadcasting is needed (non_trivial).
enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
1789
1790
1791
1792
1793
1794 template <size_t N>
1795 broadcast_trivial
1796 broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1797 ndim = std::accumulate(
1798 buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1799 return std::max(res, buf.ndim);
1800 });
1801
1802 shape.clear();
1803 shape.resize((size_t) ndim, 1);
1804
1805
1806
1807 for (size_t i = 0; i < N; ++i) {
1808 auto res_iter = shape.rbegin();
1809 auto end = buffers[i].shape.rend();
1810 for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end;
1811 ++shape_iter, ++res_iter) {
1812 const auto &dim_size_in = *shape_iter;
1813 auto &dim_size_out = *res_iter;
1814
1815
1816
1817 if (dim_size_out == 1) {
1818 dim_size_out = dim_size_in;
1819 } else if (dim_size_in != 1 && dim_size_in != dim_size_out) {
1820 pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1821 }
1822 }
1823 }
1824
1825 bool trivial_broadcast_c = true;
1826 bool trivial_broadcast_f = true;
1827 for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
1828 if (buffers[i].size == 1) {
1829 continue;
1830 }
1831
1832
1833 if (buffers[i].ndim != ndim) {
1834 return broadcast_trivial::non_trivial;
1835 }
1836
1837
1838 if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) {
1839 return broadcast_trivial::non_trivial;
1840 }
1841
1842
1843 if (trivial_broadcast_c) {
1844 ssize_t expect_stride = buffers[i].itemsize;
1845 auto end = buffers[i].shape.crend();
1846 for (auto shape_iter = buffers[i].shape.crbegin(),
1847 stride_iter = buffers[i].strides.crbegin();
1848 trivial_broadcast_c && shape_iter != end;
1849 ++shape_iter, ++stride_iter) {
1850 if (expect_stride == *stride_iter) {
1851 expect_stride *= *shape_iter;
1852 } else {
1853 trivial_broadcast_c = false;
1854 }
1855 }
1856 }
1857
1858
1859 if (trivial_broadcast_f) {
1860 ssize_t expect_stride = buffers[i].itemsize;
1861 auto end = buffers[i].shape.cend();
1862 for (auto shape_iter = buffers[i].shape.cbegin(),
1863 stride_iter = buffers[i].strides.cbegin();
1864 trivial_broadcast_f && shape_iter != end;
1865 ++shape_iter, ++stride_iter) {
1866 if (expect_stride == *stride_iter) {
1867 expect_stride *= *shape_iter;
1868 } else {
1869 trivial_broadcast_f = false;
1870 }
1871 }
1872 }
1873 }
1874
1875 return trivial_broadcast_c ? broadcast_trivial::c_trivial
1876 : trivial_broadcast_f ? broadcast_trivial::f_trivial
1877 : broadcast_trivial::non_trivial;
1878 }
1879
1880 template <typename T>
1881 struct vectorize_arg {
1882 static_assert(!std::is_rvalue_reference<T>::value,
1883 "Functions with rvalue reference arguments cannot be vectorized");
1884
1885 using call_type = remove_reference_t<T>;
1886
1887 static constexpr bool vectorize
1888 = satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value
1889 && satisfies_none_of<call_type,
1890 std::is_pointer,
1891 std::is_array,
1892 is_std_array,
1893 std::is_enum>::value
1894 && (!std::is_reference<T>::value
1895 || (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
1896
1897 using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
1898 };
1899
1900
1901 template <typename Func, typename Return, typename... Args>
1902 struct vectorize_returned_array {
1903 using Type = array_t<Return>;
1904
1905 static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
1906 if (trivial == broadcast_trivial::f_trivial) {
1907 return array_t<Return, array::f_style>(shape);
1908 }
1909 return array_t<Return>(shape);
1910 }
1911
1912 static Return *mutable_data(Type &array) { return array.mutable_data(); }
1913
1914 static Return call(Func &f, Args &...args) { return f(args...); }
1915
1916 static void call(Return *out, size_t i, Func &f, Args &...args) { out[i] = f(args...); }
1917 };
1918
1919
1920 template <typename Func, typename... Args>
1921 struct vectorize_returned_array<Func, void, Args...> {
1922 using Type = none;
1923
1924 static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }
1925
1926 static void *mutable_data(Type &) { return nullptr; }
1927
1928 static detail::void_type call(Func &f, Args &...args) {
1929 f(args...);
1930 return {};
1931 }
1932
1933 static void call(void *, size_t, Func &f, Args &...args) { f(args...); }
1934 };
1935
1936 template <typename Func, typename Return, typename... Args>
1937 struct vectorize_helper {
1938
1939
1940 #ifdef __CUDACC__
1941 public:
1942 #else
1943 private:
1944 #endif
1945
1946 static constexpr size_t N = sizeof...(Args);
1947 static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
1948 static_assert(
1949 NVectorized >= 1,
1950 "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1951
1952 public:
1953 template <typename T,
1954
1955 typename = detail::enable_if_t<
1956 !std::is_same<vectorize_helper, typename std::decay<T>::type>::value>>
1957 explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) {}
1958
1959 object operator()(typename vectorize_arg<Args>::type... args) {
1960 return run(args...,
1961 make_index_sequence<N>(),
1962 select_indices<vectorize_arg<Args>::vectorize...>(),
1963 make_index_sequence<NVectorized>());
1964 }
1965
1966 private:
1967 remove_reference_t<Func> f;
1968
1969
1970
1971 using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
1972 template <size_t Index>
1973 using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
1974
1975 using returned_array = vectorize_returned_array<Func, Return, Args...>;
1976
1977
1978
1979
1980
1981
1982
1983
1984 template <size_t... Index, size_t... VIndex, size_t... BIndex>
1985 object run(typename vectorize_arg<Args>::type &...args,
1986 index_sequence<Index...> i_seq,
1987 index_sequence<VIndex...> vi_seq,
1988 index_sequence<BIndex...> bi_seq) {
1989
1990
1991
1992
1993 std::array<void *, N> params{{reinterpret_cast<void *>(&args)...}};
1994
1995
1996 std::array<buffer_info, NVectorized> buffers{
1997 {reinterpret_cast<array *>(params[VIndex])->request()...}};
1998
1999
2000 ssize_t nd = 0;
2001 std::vector<ssize_t> shape(0);
2002 auto trivial = broadcast(buffers, nd, shape);
2003 auto ndim = (size_t) nd;
2004
2005 size_t size
2006 = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());
2007
2008
2009
2010 if (size == 1 && ndim == 0) {
2011 PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
2012 return cast(
2013 returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
2014 }
2015
2016 auto result = returned_array::create(trivial, shape);
2017
2018 PYBIND11_WARNING_PUSH
2019 #ifdef PYBIND11_DETECTED_CLANG_WITH_MISLEADING_CALL_STD_MOVE_EXPLICITLY_WARNING
2020 PYBIND11_WARNING_DISABLE_CLANG("-Wreturn-std-move")
2021 #endif
2022
2023 if (size == 0) {
2024 return result;
2025 }
2026
2027
2028 auto *mutable_data = returned_array::mutable_data(result);
2029 if (trivial == broadcast_trivial::non_trivial) {
2030 apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
2031 } else {
2032 apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
2033 }
2034
2035 return result;
2036 PYBIND11_WARNING_POP
2037 }
2038
2039 template <size_t... Index, size_t... VIndex, size_t... BIndex>
2040 void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
2041 std::array<void *, N> ¶ms,
2042 Return *out,
2043 size_t size,
2044 index_sequence<Index...>,
2045 index_sequence<VIndex...>,
2046 index_sequence<BIndex...>) {
2047
2048
2049
2050
2051 std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{
2052 {std::pair<unsigned char *&, const size_t>(
2053 reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
2054 buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>))...}};
2055
2056 for (size_t i = 0; i < size; ++i) {
2057 returned_array::call(
2058 out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
2059 for (auto &x : vecparams) {
2060 x.first += x.second;
2061 }
2062 }
2063 }
2064
2065 template <size_t... Index, size_t... VIndex, size_t... BIndex>
2066 void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
2067 std::array<void *, N> ¶ms,
2068 Return *out,
2069 size_t size,
2070 const std::vector<ssize_t> &output_shape,
2071 index_sequence<Index...>,
2072 index_sequence<VIndex...>,
2073 index_sequence<BIndex...>) {
2074
2075 multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
2076
2077 for (size_t i = 0; i < size; ++i, ++input_iter) {
2078 PYBIND11_EXPAND_SIDE_EFFECTS((params[VIndex] = input_iter.template data<BIndex>()));
2079 returned_array::call(
2080 out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
2081 }
2082 }
2083 };
2084
2085 template <typename Func, typename Return, typename... Args>
2086 vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Return (*)(Args...)) {
2087 return detail::vectorize_helper<Func, Return, Args...>(f);
2088 }
2089
2090 template <typename T, int Flags>
2091 struct handle_type_name<array_t<T, Flags>> {
2092 static constexpr auto name
2093 = const_name("numpy.ndarray[") + npy_format_descriptor<T>::name + const_name("]");
2094 };
2095
2096 PYBIND11_NAMESPACE_END(detail)
2097
2098
2099 template <typename Return, typename... Args>
2100 detail::vectorize_helper<Return (*)(Args...), Return, Args...> vectorize(Return (*f)(Args...)) {
2101 return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
2102 }
2103
2104
2105 template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
2106 auto vectorize(Func &&f)
2107 -> decltype(detail::vectorize_extractor(std::forward<Func>(f),
2108 (detail::function_signature_t<Func> *) nullptr)) {
2109 return detail::vectorize_extractor(std::forward<Func>(f),
2110 (detail::function_signature_t<Func> *) nullptr);
2111 }
2112
2113
2114 template <typename Return,
2115 typename Class,
2116 typename... Args,
2117 typename Helper = detail::vectorize_helper<
2118 decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())),
2119 Return,
2120 Class *,
2121 Args...>>
2122 Helper vectorize(Return (Class::*f)(Args...)) {
2123 return Helper(std::mem_fn(f));
2124 }
2125
2126
2127 template <typename Return,
2128 typename Class,
2129 typename... Args,
2130 typename Helper = detail::vectorize_helper<
2131 decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())),
2132 Return,
2133 const Class *,
2134 Args...>>
2135 Helper vectorize(Return (Class::*f)(Args...) const) {
2136 return Helper(std::mem_fn(f));
2137 }
2138
2139 PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)