File indexing completed on 2026-01-10 10:20:09
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010 #include <pybind11/numpy.h>
0011
0012 #include "common.h"
0013
0014 #if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
0015 static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
0016 #endif
0017
0018
0019 PYBIND11_WARNING_PUSH
0020 PYBIND11_WARNING_DISABLE_MSVC(4554)
0021 PYBIND11_WARNING_DISABLE_MSVC(4127)
0022 #if defined(__MINGW32__)
0023 PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
0024 #endif
0025
0026 #include <unsupported/Eigen/CXX11/Tensor>
0027
0028 PYBIND11_WARNING_POP
0029
0030 static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 0),
0031 "Eigen Tensor support in pybind11 requires Eigen >= 3.3.0");
0032
0033 PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
0034
0035 PYBIND11_WARNING_DISABLE_MSVC(4127)
0036
0037 PYBIND11_NAMESPACE_BEGIN(detail)
0038
0039 inline bool is_tensor_aligned(const void *data) {
0040 return (reinterpret_cast<std::size_t>(data) % EIGEN_DEFAULT_ALIGN_BYTES) == 0;
0041 }
0042
0043 template <typename T>
0044 constexpr int compute_array_flag_from_tensor() {
0045 static_assert((static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor))
0046 || (static_cast<int>(T::Layout) == static_cast<int>(Eigen::ColMajor)),
0047 "Layout must be row or column major");
0048 return (static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor)) ? array::c_style
0049 : array::f_style;
0050 }
0051
// Primary template, intentionally empty. Only the specializations declare `ValidType`;
// the tensor type_caster specializations name `eigen_tensor_helper<...>::ValidType` in
// their template arguments, so they can only match types that have a specialization.
template <typename Unspecialized>
struct eigen_tensor_helper {};
0054
0055 template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
0056 struct eigen_tensor_helper<Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
0057 using Type = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
0058 using ValidType = void;
0059
0060 static Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape(const Type &f) {
0061 return f.dimensions();
0062 }
0063
0064 static constexpr bool
0065 is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> & ) {
0066 return true;
0067 }
0068
0069 template <typename T>
0070 struct helper {};
0071
0072 template <size_t... Is>
0073 struct helper<index_sequence<Is...>> {
0074 static constexpr auto value = ::pybind11::detail::concat(const_name(((void) Is, "?"))...);
0075 };
0076
0077 static constexpr auto dimensions_descriptor
0078 = helper<decltype(make_index_sequence<Type::NumIndices>())>::value;
0079
0080 template <typename... Args>
0081 static Type *alloc(Args &&...args) {
0082 return new Type(std::forward<Args>(args)...);
0083 }
0084
0085 static void free(Type *tensor) { delete tensor; }
0086 };
0087
0088 template <typename Scalar_, typename std::ptrdiff_t... Indices, int Options_, typename IndexType>
0089 struct eigen_tensor_helper<
0090 Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
0091 using Type = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
0092 using ValidType = void;
0093
0094 static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices>
0095 get_shape(const Type & ) {
0096 return get_shape();
0097 }
0098
0099 static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape() {
0100 return Eigen::DSizes<typename Type::Index, Type::NumIndices>(Indices...);
0101 }
0102
0103 static bool
0104 is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> &shape) {
0105 return get_shape() == shape;
0106 }
0107
0108 static constexpr auto dimensions_descriptor
0109 = ::pybind11::detail::concat(const_name<Indices>()...);
0110
0111 template <typename... Args>
0112 static Type *alloc(Args &&...args) {
0113 Eigen::aligned_allocator<Type> allocator;
0114 return ::new (allocator.allocate(1)) Type(std::forward<Args>(args)...);
0115 }
0116
0117 static void free(Type *tensor) {
0118 Eigen::aligned_allocator<Type> allocator;
0119 tensor->~Type();
0120 allocator.deallocate(tensor, 1);
0121 }
0122 };
0123
0124 template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
0125 struct get_tensor_descriptor {
0126 static constexpr auto details
0127 = const_name<NeedsWriteable>(", \"flags.writeable\"", "") + const_name
0128 < static_cast<int>(Type::Layout)
0129 == static_cast<int>(Eigen::RowMajor)
0130 > (", \"flags.c_contiguous\"", ", \"flags.f_contiguous\"");
0131 static constexpr auto value
0132 = const_name("typing.Annotated[")
0133 + io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
0134 + npy_format_descriptor<typename Type::Scalar>::name + io_name("", "]")
0135 + const_name(", \"[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
0136 + const_name("]\"") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
0137 };
0138
0139
0140
0141
0142
0143
0144 PYBIND11_WARNING_PUSH
0145 PYBIND11_WARNING_DISABLE_GCC("-Wtype-limits")
0146
0147 template <typename T, int size>
0148 std::vector<T> convert_dsizes_to_vector(const Eigen::DSizes<T, size> &arr) {
0149 std::vector<T> result(size);
0150
0151 for (size_t i = 0; i < size; i++) {
0152 result[i] = arr[i];
0153 }
0154
0155 return result;
0156 }
0157
0158 template <typename T, int size>
0159 Eigen::DSizes<T, size> get_shape_for_array(const array &arr) {
0160 Eigen::DSizes<T, size> result;
0161 const T *shape = arr.shape();
0162 for (size_t i = 0; i < size; i++) {
0163 result[i] = shape[i];
0164 }
0165
0166 return result;
0167 }
0168
0169 PYBIND11_WARNING_POP
0170
0171 template <typename Type>
0172 struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
0173 static_assert(!std::is_pointer<typename Type::Scalar>::value,
0174 PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
0175 using Helper = eigen_tensor_helper<Type>;
0176 static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
0177 PYBIND11_TYPE_CASTER(Type, temp_name);
0178
0179 bool load(handle src, bool convert) {
0180 if (!convert) {
0181 if (!isinstance<array>(src)) {
0182 return false;
0183 }
0184 array temp = array::ensure(src);
0185 if (!temp) {
0186 return false;
0187 }
0188
0189 if (!temp.dtype().is(dtype::of<typename Type::Scalar>())) {
0190 return false;
0191 }
0192 }
0193
0194 array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()> arr(
0195 reinterpret_borrow<object>(src));
0196
0197 if (arr.ndim() != Type::NumIndices) {
0198 return false;
0199 }
0200 auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
0201
0202 if (!Helper::is_correct_shape(shape)) {
0203 return false;
0204 }
0205
0206 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0207 auto data_pointer = arr.data();
0208 #else
0209
0210 auto data_pointer = const_cast<typename Type::Scalar *>(arr.data());
0211 #endif
0212
0213 if (is_tensor_aligned(arr.data())) {
0214 value = Eigen::TensorMap<const Type, Eigen::Aligned>(data_pointer, shape);
0215 } else {
0216 value = Eigen::TensorMap<const Type>(data_pointer, shape);
0217 }
0218
0219 return true;
0220 }
0221
0222 static handle cast(Type &&src, return_value_policy policy, handle parent) {
0223 if (policy == return_value_policy::reference
0224 || policy == return_value_policy::reference_internal) {
0225 pybind11_fail("Cannot use a reference return value policy for an rvalue");
0226 }
0227 return cast_impl(&src, return_value_policy::move, parent);
0228 }
0229
0230 static handle cast(const Type &&src, return_value_policy policy, handle parent) {
0231 if (policy == return_value_policy::reference
0232 || policy == return_value_policy::reference_internal) {
0233 pybind11_fail("Cannot use a reference return value policy for an rvalue");
0234 }
0235 return cast_impl(&src, return_value_policy::move, parent);
0236 }
0237
0238 static handle cast(Type &src, return_value_policy policy, handle parent) {
0239 if (policy == return_value_policy::automatic
0240 || policy == return_value_policy::automatic_reference) {
0241 policy = return_value_policy::copy;
0242 }
0243 return cast_impl(&src, policy, parent);
0244 }
0245
0246 static handle cast(const Type &src, return_value_policy policy, handle parent) {
0247 if (policy == return_value_policy::automatic
0248 || policy == return_value_policy::automatic_reference) {
0249 policy = return_value_policy::copy;
0250 }
0251 return cast(&src, policy, parent);
0252 }
0253
0254 static handle cast(Type *src, return_value_policy policy, handle parent) {
0255 if (policy == return_value_policy::automatic) {
0256 policy = return_value_policy::take_ownership;
0257 } else if (policy == return_value_policy::automatic_reference) {
0258 policy = return_value_policy::reference;
0259 }
0260 return cast_impl(src, policy, parent);
0261 }
0262
0263 static handle cast(const Type *src, return_value_policy policy, handle parent) {
0264 if (policy == return_value_policy::automatic) {
0265 policy = return_value_policy::take_ownership;
0266 } else if (policy == return_value_policy::automatic_reference) {
0267 policy = return_value_policy::reference;
0268 }
0269 return cast_impl(src, policy, parent);
0270 }
0271
0272 template <typename C>
0273 static handle cast_impl(C *src, return_value_policy policy, handle parent) {
0274 object parent_object;
0275 bool writeable = false;
0276 switch (policy) {
0277 case return_value_policy::move:
0278 if (std::is_const<C>::value) {
0279 pybind11_fail("Cannot move from a constant reference");
0280 }
0281
0282 src = Helper::alloc(std::move(*src));
0283
0284 parent_object
0285 = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0286 writeable = true;
0287 break;
0288
0289 case return_value_policy::take_ownership:
0290 if (std::is_const<C>::value) {
0291
0292
0293 Helper::free(const_cast<Type *>(src));
0294 pybind11_fail("Cannot take ownership of a const reference");
0295 }
0296
0297 parent_object
0298 = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0299 writeable = true;
0300 break;
0301
0302 case return_value_policy::copy:
0303 writeable = true;
0304 break;
0305
0306 case return_value_policy::reference:
0307 parent_object = none();
0308 writeable = !std::is_const<C>::value;
0309 break;
0310
0311 case return_value_policy::reference_internal:
0312
0313 if (!parent) {
0314 pybind11_fail("Cannot use reference internal when there is no parent");
0315 }
0316 parent_object = reinterpret_borrow<object>(parent);
0317 writeable = !std::is_const<C>::value;
0318 break;
0319
0320 default:
0321 pybind11_fail("pybind11 bug in eigen.h, please file a bug report");
0322 }
0323
0324 auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
0325 convert_dsizes_to_vector(Helper::get_shape(*src)), src->data(), parent_object);
0326
0327 if (!writeable) {
0328 array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
0329 }
0330
0331 return result.release();
0332 }
0333 };
0334
0335 template <typename StoragePointerType,
0336 bool needs_writeable,
0337 enable_if_t<!needs_writeable, bool> = true>
0338 StoragePointerType get_array_data_for_type(array &arr) {
0339 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0340 return reinterpret_cast<StoragePointerType>(arr.data());
0341 #else
0342
0343 return reinterpret_cast<StoragePointerType>(const_cast<void *>(arr.data()));
0344 #endif
0345 }
0346
0347 template <typename StoragePointerType,
0348 bool needs_writeable,
0349 enable_if_t<needs_writeable, bool> = true>
0350 StoragePointerType get_array_data_for_type(array &arr) {
0351 return reinterpret_cast<StoragePointerType>(arr.mutable_data());
0352 }
0353
0354 template <typename T, typename = void>
0355 struct get_storage_pointer_type;
0356
0357 template <typename MapType>
0358 struct get_storage_pointer_type<MapType, void_t<typename MapType::StoragePointerType>> {
0359 using SPT = typename MapType::StoragePointerType;
0360 };
0361
0362 template <typename MapType>
0363 struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType>> {
0364 using SPT = typename MapType::PointerArgType;
0365 };
0366
0367 template <typename Type, int Options>
0368 struct type_caster<Eigen::TensorMap<Type, Options>,
0369 typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
0370 static_assert(!std::is_pointer<typename Type::Scalar>::value,
0371 PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
0372 using MapType = Eigen::TensorMap<Type, Options>;
0373 using Helper = eigen_tensor_helper<remove_cv_t<Type>>;
0374
0375 bool load(handle src, bool ) {
0376
0377 if (!isinstance<array>(src)) {
0378 return false;
0379 }
0380 auto arr = reinterpret_borrow<array>(src);
0381 if ((arr.flags() & compute_array_flag_from_tensor<Type>()) == 0) {
0382 return false;
0383 }
0384
0385 if (!arr.dtype().is(dtype::of<typename Type::Scalar>())) {
0386 return false;
0387 }
0388
0389 if (arr.ndim() != Type::NumIndices) {
0390 return false;
0391 }
0392
0393 constexpr bool is_aligned = (Options & Eigen::Aligned) != 0;
0394
0395 if (is_aligned && !is_tensor_aligned(arr.data())) {
0396 return false;
0397 }
0398
0399 auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
0400
0401 if (!Helper::is_correct_shape(shape)) {
0402 return false;
0403 }
0404
0405 if (needs_writeable && !arr.writeable()) {
0406 return false;
0407 }
0408
0409 auto result = get_array_data_for_type<typename get_storage_pointer_type<MapType>::SPT,
0410 needs_writeable>(arr);
0411
0412 value.reset(new MapType(std::move(result), std::move(shape)));
0413
0414 return true;
0415 }
0416
0417 static handle cast(MapType &&src, return_value_policy policy, handle parent) {
0418 return cast_impl(&src, policy, parent);
0419 }
0420
0421 static handle cast(const MapType &&src, return_value_policy policy, handle parent) {
0422 return cast_impl(&src, policy, parent);
0423 }
0424
0425 static handle cast(MapType &src, return_value_policy policy, handle parent) {
0426 if (policy == return_value_policy::automatic
0427 || policy == return_value_policy::automatic_reference) {
0428 policy = return_value_policy::copy;
0429 }
0430 return cast_impl(&src, policy, parent);
0431 }
0432
0433 static handle cast(const MapType &src, return_value_policy policy, handle parent) {
0434 if (policy == return_value_policy::automatic
0435 || policy == return_value_policy::automatic_reference) {
0436 policy = return_value_policy::copy;
0437 }
0438 return cast(&src, policy, parent);
0439 }
0440
0441 static handle cast(MapType *src, return_value_policy policy, handle parent) {
0442 if (policy == return_value_policy::automatic) {
0443 policy = return_value_policy::take_ownership;
0444 } else if (policy == return_value_policy::automatic_reference) {
0445 policy = return_value_policy::reference;
0446 }
0447 return cast_impl(src, policy, parent);
0448 }
0449
0450 static handle cast(const MapType *src, return_value_policy policy, handle parent) {
0451 if (policy == return_value_policy::automatic) {
0452 policy = return_value_policy::take_ownership;
0453 } else if (policy == return_value_policy::automatic_reference) {
0454 policy = return_value_policy::reference;
0455 }
0456 return cast_impl(src, policy, parent);
0457 }
0458
0459 template <typename C>
0460 static handle cast_impl(C *src, return_value_policy policy, handle parent) {
0461 object parent_object;
0462 constexpr bool writeable = !std::is_const<C>::value;
0463 switch (policy) {
0464 case return_value_policy::reference:
0465 parent_object = none();
0466 break;
0467
0468 case return_value_policy::reference_internal:
0469
0470 if (!parent) {
0471 pybind11_fail("Cannot use reference internal when there is no parent");
0472 }
0473 parent_object = reinterpret_borrow<object>(parent);
0474 break;
0475
0476 default:
0477
0478 pybind11_fail("Invalid return_value_policy for Eigen Map type, must be either "
0479 "reference or reference_internal");
0480 }
0481
0482 auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
0483 convert_dsizes_to_vector(Helper::get_shape(*src)),
0484 src->data(),
0485 std::move(parent_object));
0486
0487 if (!writeable) {
0488 array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
0489 }
0490
0491 return result.release();
0492 }
0493
0494 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0495
0496 static constexpr bool needs_writeable = !std::is_const<typename std::remove_pointer<
0497 typename get_storage_pointer_type<MapType>::SPT>::type>::value;
0498 #else
0499
0500 static constexpr bool needs_writeable = !std::is_const<Type>::value;
0501 #endif
0502
0503 protected:
0504
0505 std::unique_ptr<MapType> value;
0506
0507 public:
0508
0509
0510 static constexpr auto name
0511 = return_descr(get_tensor_descriptor<Type, true, needs_writeable>::value);
0512 explicit operator MapType *() { return value.get(); }
0513 explicit operator MapType &() { return *value; }
0514 explicit operator MapType &&() && { return std::move(*value); }
0515
0516 template <typename T_>
0517 using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
0518 };
0519
0520 PYBIND11_NAMESPACE_END(detail)
0521 PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)