File indexing completed on 2025-01-18 10:06:13
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010 #include <pybind11/numpy.h>
0011
0012 #include "common.h"
0013
0014 #if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
0015 static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
0016 #endif
0017
0018
0019 PYBIND11_WARNING_PUSH
0020 PYBIND11_WARNING_DISABLE_MSVC(4554)
0021 PYBIND11_WARNING_DISABLE_MSVC(4127)
0022 #if defined(__MINGW32__)
0023 PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
0024 #endif
0025
0026 #include <unsupported/Eigen/CXX11/Tensor>
0027
0028 PYBIND11_WARNING_POP
0029
0030 static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 0),
0031 "Eigen Tensor support in pybind11 requires Eigen >= 3.3.0");
0032
0033 PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
0034
0035 PYBIND11_WARNING_DISABLE_MSVC(4127)
0036
0037 PYBIND11_NAMESPACE_BEGIN(detail)
0038
0039 inline bool is_tensor_aligned(const void *data) {
0040 return (reinterpret_cast<std::size_t>(data) % EIGEN_DEFAULT_ALIGN_BYTES) == 0;
0041 }
0042
0043 template <typename T>
0044 constexpr int compute_array_flag_from_tensor() {
0045 static_assert((static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor))
0046 || (static_cast<int>(T::Layout) == static_cast<int>(Eigen::ColMajor)),
0047 "Layout must be row or column major");
0048 return (static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor)) ? array::c_style
0049 : array::f_style;
0050 }
0051
/// Primary template: deliberately empty.
///
/// It has no `ValidType` member, so the type_caster partial specialisations in this file
/// (which name `eigen_tensor_helper<...>::ValidType` in their template arguments) do not match
/// types that lack a dedicated specialisation of this helper.
template <class Unsupported>
struct eigen_tensor_helper {};
0054
0055 template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
0056 struct eigen_tensor_helper<Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
0057 using Type = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
0058 using ValidType = void;
0059
0060 static Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape(const Type &f) {
0061 return f.dimensions();
0062 }
0063
0064 static constexpr bool
0065 is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> & ) {
0066 return true;
0067 }
0068
0069 template <typename T>
0070 struct helper {};
0071
0072 template <size_t... Is>
0073 struct helper<index_sequence<Is...>> {
0074 static constexpr auto value = ::pybind11::detail::concat(const_name(((void) Is, "?"))...);
0075 };
0076
0077 static constexpr auto dimensions_descriptor
0078 = helper<decltype(make_index_sequence<Type::NumIndices>())>::value;
0079
0080 template <typename... Args>
0081 static Type *alloc(Args &&...args) {
0082 return new Type(std::forward<Args>(args)...);
0083 }
0084
0085 static void free(Type *tensor) { delete tensor; }
0086 };
0087
0088 template <typename Scalar_, typename std::ptrdiff_t... Indices, int Options_, typename IndexType>
0089 struct eigen_tensor_helper<
0090 Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
0091 using Type = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
0092 using ValidType = void;
0093
0094 static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices>
0095 get_shape(const Type & ) {
0096 return get_shape();
0097 }
0098
0099 static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape() {
0100 return Eigen::DSizes<typename Type::Index, Type::NumIndices>(Indices...);
0101 }
0102
0103 static bool
0104 is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> &shape) {
0105 return get_shape() == shape;
0106 }
0107
0108 static constexpr auto dimensions_descriptor
0109 = ::pybind11::detail::concat(const_name<Indices>()...);
0110
0111 template <typename... Args>
0112 static Type *alloc(Args &&...args) {
0113 Eigen::aligned_allocator<Type> allocator;
0114 return ::new (allocator.allocate(1)) Type(std::forward<Args>(args)...);
0115 }
0116
0117 static void free(Type *tensor) {
0118 Eigen::aligned_allocator<Type> allocator;
0119 tensor->~Type();
0120 allocator.deallocate(tensor, 1);
0121 }
0122 };
0123
0124 template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
0125 struct get_tensor_descriptor {
0126 static constexpr auto details
0127 = const_name<NeedsWriteable>(", flags.writeable", "")
0128 + const_name<static_cast<int>(Type::Layout) == static_cast<int>(Eigen::RowMajor)>(
0129 ", flags.c_contiguous", ", flags.f_contiguous");
0130 static constexpr auto value
0131 = const_name("numpy.ndarray[") + npy_format_descriptor<typename Type::Scalar>::name
0132 + const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
0133 + const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
0134 };
0135
0136
0137
0138
0139
0140
0141 PYBIND11_WARNING_PUSH
0142 PYBIND11_WARNING_DISABLE_GCC("-Wtype-limits")
0143
0144 template <typename T, int size>
0145 std::vector<T> convert_dsizes_to_vector(const Eigen::DSizes<T, size> &arr) {
0146 std::vector<T> result(size);
0147
0148 for (size_t i = 0; i < size; i++) {
0149 result[i] = arr[i];
0150 }
0151
0152 return result;
0153 }
0154
0155 template <typename T, int size>
0156 Eigen::DSizes<T, size> get_shape_for_array(const array &arr) {
0157 Eigen::DSizes<T, size> result;
0158 const T *shape = arr.shape();
0159 for (size_t i = 0; i < size; i++) {
0160 result[i] = shape[i];
0161 }
0162
0163 return result;
0164 }
0165
0166 PYBIND11_WARNING_POP
0167
0168 template <typename Type>
0169 struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
0170 static_assert(!std::is_pointer<typename Type::Scalar>::value,
0171 PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
0172 using Helper = eigen_tensor_helper<Type>;
0173 static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
0174 PYBIND11_TYPE_CASTER(Type, temp_name);
0175
0176 bool load(handle src, bool convert) {
0177 if (!convert) {
0178 if (!isinstance<array>(src)) {
0179 return false;
0180 }
0181 array temp = array::ensure(src);
0182 if (!temp) {
0183 return false;
0184 }
0185
0186 if (!temp.dtype().is(dtype::of<typename Type::Scalar>())) {
0187 return false;
0188 }
0189 }
0190
0191 array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()> arr(
0192 reinterpret_borrow<object>(src));
0193
0194 if (arr.ndim() != Type::NumIndices) {
0195 return false;
0196 }
0197 auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
0198
0199 if (!Helper::is_correct_shape(shape)) {
0200 return false;
0201 }
0202
0203 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0204 auto data_pointer = arr.data();
0205 #else
0206
0207 auto data_pointer = const_cast<typename Type::Scalar *>(arr.data());
0208 #endif
0209
0210 if (is_tensor_aligned(arr.data())) {
0211 value = Eigen::TensorMap<const Type, Eigen::Aligned>(data_pointer, shape);
0212 } else {
0213 value = Eigen::TensorMap<const Type>(data_pointer, shape);
0214 }
0215
0216 return true;
0217 }
0218
0219 static handle cast(Type &&src, return_value_policy policy, handle parent) {
0220 if (policy == return_value_policy::reference
0221 || policy == return_value_policy::reference_internal) {
0222 pybind11_fail("Cannot use a reference return value policy for an rvalue");
0223 }
0224 return cast_impl(&src, return_value_policy::move, parent);
0225 }
0226
0227 static handle cast(const Type &&src, return_value_policy policy, handle parent) {
0228 if (policy == return_value_policy::reference
0229 || policy == return_value_policy::reference_internal) {
0230 pybind11_fail("Cannot use a reference return value policy for an rvalue");
0231 }
0232 return cast_impl(&src, return_value_policy::move, parent);
0233 }
0234
0235 static handle cast(Type &src, return_value_policy policy, handle parent) {
0236 if (policy == return_value_policy::automatic
0237 || policy == return_value_policy::automatic_reference) {
0238 policy = return_value_policy::copy;
0239 }
0240 return cast_impl(&src, policy, parent);
0241 }
0242
0243 static handle cast(const Type &src, return_value_policy policy, handle parent) {
0244 if (policy == return_value_policy::automatic
0245 || policy == return_value_policy::automatic_reference) {
0246 policy = return_value_policy::copy;
0247 }
0248 return cast(&src, policy, parent);
0249 }
0250
0251 static handle cast(Type *src, return_value_policy policy, handle parent) {
0252 if (policy == return_value_policy::automatic) {
0253 policy = return_value_policy::take_ownership;
0254 } else if (policy == return_value_policy::automatic_reference) {
0255 policy = return_value_policy::reference;
0256 }
0257 return cast_impl(src, policy, parent);
0258 }
0259
0260 static handle cast(const Type *src, return_value_policy policy, handle parent) {
0261 if (policy == return_value_policy::automatic) {
0262 policy = return_value_policy::take_ownership;
0263 } else if (policy == return_value_policy::automatic_reference) {
0264 policy = return_value_policy::reference;
0265 }
0266 return cast_impl(src, policy, parent);
0267 }
0268
0269 template <typename C>
0270 static handle cast_impl(C *src, return_value_policy policy, handle parent) {
0271 object parent_object;
0272 bool writeable = false;
0273 switch (policy) {
0274 case return_value_policy::move:
0275 if (std::is_const<C>::value) {
0276 pybind11_fail("Cannot move from a constant reference");
0277 }
0278
0279 src = Helper::alloc(std::move(*src));
0280
0281 parent_object
0282 = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0283 writeable = true;
0284 break;
0285
0286 case return_value_policy::take_ownership:
0287 if (std::is_const<C>::value) {
0288
0289
0290 Helper::free(const_cast<Type *>(src));
0291 pybind11_fail("Cannot take ownership of a const reference");
0292 }
0293
0294 parent_object
0295 = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0296 writeable = true;
0297 break;
0298
0299 case return_value_policy::copy:
0300 writeable = true;
0301 break;
0302
0303 case return_value_policy::reference:
0304 parent_object = none();
0305 writeable = !std::is_const<C>::value;
0306 break;
0307
0308 case return_value_policy::reference_internal:
0309
0310 if (!parent) {
0311 pybind11_fail("Cannot use reference internal when there is no parent");
0312 }
0313 parent_object = reinterpret_borrow<object>(parent);
0314 writeable = !std::is_const<C>::value;
0315 break;
0316
0317 default:
0318 pybind11_fail("pybind11 bug in eigen.h, please file a bug report");
0319 }
0320
0321 auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
0322 convert_dsizes_to_vector(Helper::get_shape(*src)), src->data(), parent_object);
0323
0324 if (!writeable) {
0325 array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
0326 }
0327
0328 return result.release();
0329 }
0330 };
0331
0332 template <typename StoragePointerType,
0333 bool needs_writeable,
0334 enable_if_t<!needs_writeable, bool> = true>
0335 StoragePointerType get_array_data_for_type(array &arr) {
0336 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0337 return reinterpret_cast<StoragePointerType>(arr.data());
0338 #else
0339
0340 return reinterpret_cast<StoragePointerType>(const_cast<void *>(arr.data()));
0341 #endif
0342 }
0343
0344 template <typename StoragePointerType,
0345 bool needs_writeable,
0346 enable_if_t<needs_writeable, bool> = true>
0347 StoragePointerType get_array_data_for_type(array &arr) {
0348 return reinterpret_cast<StoragePointerType>(arr.mutable_data());
0349 }
0350
0351 template <typename T, typename = void>
0352 struct get_storage_pointer_type;
0353
0354 template <typename MapType>
0355 struct get_storage_pointer_type<MapType, void_t<typename MapType::StoragePointerType>> {
0356 using SPT = typename MapType::StoragePointerType;
0357 };
0358
0359 template <typename MapType>
0360 struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType>> {
0361 using SPT = typename MapType::PointerArgType;
0362 };
0363
0364 template <typename Type, int Options>
0365 struct type_caster<Eigen::TensorMap<Type, Options>,
0366 typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
0367 static_assert(!std::is_pointer<typename Type::Scalar>::value,
0368 PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
0369 using MapType = Eigen::TensorMap<Type, Options>;
0370 using Helper = eigen_tensor_helper<remove_cv_t<Type>>;
0371
0372 bool load(handle src, bool ) {
0373
0374 if (!isinstance<array>(src)) {
0375 return false;
0376 }
0377 auto arr = reinterpret_borrow<array>(src);
0378 if ((arr.flags() & compute_array_flag_from_tensor<Type>()) == 0) {
0379 return false;
0380 }
0381
0382 if (!arr.dtype().is(dtype::of<typename Type::Scalar>())) {
0383 return false;
0384 }
0385
0386 if (arr.ndim() != Type::NumIndices) {
0387 return false;
0388 }
0389
0390 constexpr bool is_aligned = (Options & Eigen::Aligned) != 0;
0391
0392 if (is_aligned && !is_tensor_aligned(arr.data())) {
0393 return false;
0394 }
0395
0396 auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
0397
0398 if (!Helper::is_correct_shape(shape)) {
0399 return false;
0400 }
0401
0402 if (needs_writeable && !arr.writeable()) {
0403 return false;
0404 }
0405
0406 auto result = get_array_data_for_type<typename get_storage_pointer_type<MapType>::SPT,
0407 needs_writeable>(arr);
0408
0409 value.reset(new MapType(std::move(result), std::move(shape)));
0410
0411 return true;
0412 }
0413
0414 static handle cast(MapType &&src, return_value_policy policy, handle parent) {
0415 return cast_impl(&src, policy, parent);
0416 }
0417
0418 static handle cast(const MapType &&src, return_value_policy policy, handle parent) {
0419 return cast_impl(&src, policy, parent);
0420 }
0421
0422 static handle cast(MapType &src, return_value_policy policy, handle parent) {
0423 if (policy == return_value_policy::automatic
0424 || policy == return_value_policy::automatic_reference) {
0425 policy = return_value_policy::copy;
0426 }
0427 return cast_impl(&src, policy, parent);
0428 }
0429
0430 static handle cast(const MapType &src, return_value_policy policy, handle parent) {
0431 if (policy == return_value_policy::automatic
0432 || policy == return_value_policy::automatic_reference) {
0433 policy = return_value_policy::copy;
0434 }
0435 return cast(&src, policy, parent);
0436 }
0437
0438 static handle cast(MapType *src, return_value_policy policy, handle parent) {
0439 if (policy == return_value_policy::automatic) {
0440 policy = return_value_policy::take_ownership;
0441 } else if (policy == return_value_policy::automatic_reference) {
0442 policy = return_value_policy::reference;
0443 }
0444 return cast_impl(src, policy, parent);
0445 }
0446
0447 static handle cast(const MapType *src, return_value_policy policy, handle parent) {
0448 if (policy == return_value_policy::automatic) {
0449 policy = return_value_policy::take_ownership;
0450 } else if (policy == return_value_policy::automatic_reference) {
0451 policy = return_value_policy::reference;
0452 }
0453 return cast_impl(src, policy, parent);
0454 }
0455
0456 template <typename C>
0457 static handle cast_impl(C *src, return_value_policy policy, handle parent) {
0458 object parent_object;
0459 constexpr bool writeable = !std::is_const<C>::value;
0460 switch (policy) {
0461 case return_value_policy::reference:
0462 parent_object = none();
0463 break;
0464
0465 case return_value_policy::reference_internal:
0466
0467 if (!parent) {
0468 pybind11_fail("Cannot use reference internal when there is no parent");
0469 }
0470 parent_object = reinterpret_borrow<object>(parent);
0471 break;
0472
0473 default:
0474
0475 pybind11_fail("Invalid return_value_policy for Eigen Map type, must be either "
0476 "reference or reference_internal");
0477 }
0478
0479 auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
0480 convert_dsizes_to_vector(Helper::get_shape(*src)),
0481 src->data(),
0482 std::move(parent_object));
0483
0484 if (!writeable) {
0485 array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
0486 }
0487
0488 return result.release();
0489 }
0490
0491 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0492
0493 static constexpr bool needs_writeable = !std::is_const<typename std::remove_pointer<
0494 typename get_storage_pointer_type<MapType>::SPT>::type>::value;
0495 #else
0496
0497 static constexpr bool needs_writeable = !std::is_const<Type>::value;
0498 #endif
0499
0500 protected:
0501
0502 std::unique_ptr<MapType> value;
0503
0504 public:
0505 static constexpr auto name = get_tensor_descriptor<Type, true, needs_writeable>::value;
0506 explicit operator MapType *() { return value.get(); }
0507 explicit operator MapType &() { return *value; }
0508 explicit operator MapType &&() && { return std::move(*value); }
0509
0510 template <typename T_>
0511 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
0512 };
0513
0514 PYBIND11_NAMESPACE_END(detail)
0515 PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)