Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-01-10 10:20:09

0001 /*
0002     pybind11/eigen/tensor.h: Transparent conversion for Eigen tensors
0003 
0004     All rights reserved. Use of this source code is governed by a
0005     BSD-style license that can be found in the LICENSE file.
0006 */
0007 
0008 #pragma once
0009 
0010 #include <pybind11/numpy.h>
0011 
0012 #include "common.h"
0013 
// Reject old GCC releases for Eigen Tensor support. Clang and the Intel compiler are excluded
// explicitly, presumably because they also define __GNUC__. Note that the condition admits only
// __GNUC__ >= 6, so every GCC 5.x release fails even though the message says "> 5.0".
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
#endif
0017 
0018 // Disable warnings for Eigen
0019 PYBIND11_WARNING_PUSH
0020 PYBIND11_WARNING_DISABLE_MSVC(4554)
0021 PYBIND11_WARNING_DISABLE_MSVC(4127)
0022 #if defined(__MINGW32__)
0023 PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
0024 #endif
0025 
0026 #include <unsupported/Eigen/CXX11/Tensor>
0027 
0028 PYBIND11_WARNING_POP
0029 
// The tensor casters need at least Eigen 3.3.0. Code below also branches on Eigen 3.4.0
// (EIGEN_VERSION_AT_LEAST(3, 4, 0)) to work around differences in older versions.
static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 0),
              "Eigen Tensor support in pybind11 requires Eigen >= 3.3.0");
0032 
0033 PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
0034 
// Silence MSVC warning C4127 ("conditional expression is constant"). Compile-time conditions such
// as `if (std::is_const<C>::value)` in the casters of this header would otherwise trigger it.
// NOTE(review): this is not wrapped in PYBIND11_WARNING_PUSH/POP, so it appears to remain in
// effect for the rest of the translation unit, including user code after this header.
PYBIND11_WARNING_DISABLE_MSVC(4127)
0036 
0037 PYBIND11_NAMESPACE_BEGIN(detail)
0038 
0039 inline bool is_tensor_aligned(const void *data) {
0040     return (reinterpret_cast<std::size_t>(data) % EIGEN_DEFAULT_ALIGN_BYTES) == 0;
0041 }
0042 
0043 template <typename T>
0044 constexpr int compute_array_flag_from_tensor() {
0045     static_assert((static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor))
0046                       || (static_cast<int>(T::Layout) == static_cast<int>(Eigen::ColMajor)),
0047                   "Layout must be row or column major");
0048     return (static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor)) ? array::c_style
0049                                                                               : array::f_style;
0050 }
0051 
/// Primary template, deliberately empty. Only the Eigen::Tensor and Eigen::TensorFixedSize
/// specializations define the ValidType member that the tensor type casters key on.
template <class T>
struct eigen_tensor_helper {
};
0054 
0055 template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
0056 struct eigen_tensor_helper<Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
0057     using Type = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
0058     using ValidType = void;
0059 
0060     static Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape(const Type &f) {
0061         return f.dimensions();
0062     }
0063 
0064     static constexpr bool
0065     is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> & /*shape*/) {
0066         return true;
0067     }
0068 
0069     template <typename T>
0070     struct helper {};
0071 
0072     template <size_t... Is>
0073     struct helper<index_sequence<Is...>> {
0074         static constexpr auto value = ::pybind11::detail::concat(const_name(((void) Is, "?"))...);
0075     };
0076 
0077     static constexpr auto dimensions_descriptor
0078         = helper<decltype(make_index_sequence<Type::NumIndices>())>::value;
0079 
0080     template <typename... Args>
0081     static Type *alloc(Args &&...args) {
0082         return new Type(std::forward<Args>(args)...);
0083     }
0084 
0085     static void free(Type *tensor) { delete tensor; }
0086 };
0087 
0088 template <typename Scalar_, typename std::ptrdiff_t... Indices, int Options_, typename IndexType>
0089 struct eigen_tensor_helper<
0090     Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
0091     using Type = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
0092     using ValidType = void;
0093 
0094     static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices>
0095     get_shape(const Type & /*f*/) {
0096         return get_shape();
0097     }
0098 
0099     static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape() {
0100         return Eigen::DSizes<typename Type::Index, Type::NumIndices>(Indices...);
0101     }
0102 
0103     static bool
0104     is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> &shape) {
0105         return get_shape() == shape;
0106     }
0107 
0108     static constexpr auto dimensions_descriptor
0109         = ::pybind11::detail::concat(const_name<Indices>()...);
0110 
0111     template <typename... Args>
0112     static Type *alloc(Args &&...args) {
0113         Eigen::aligned_allocator<Type> allocator;
0114         return ::new (allocator.allocate(1)) Type(std::forward<Args>(args)...);
0115     }
0116 
0117     static void free(Type *tensor) {
0118         Eigen::aligned_allocator<Type> allocator;
0119         tensor->~Type();
0120         allocator.deallocate(tensor, 1);
0121     }
0122 };
0123 
0124 template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
0125 struct get_tensor_descriptor {
0126     static constexpr auto details
0127         = const_name<NeedsWriteable>(", \"flags.writeable\"", "") + const_name
0128               < static_cast<int>(Type::Layout)
0129           == static_cast<int>(Eigen::RowMajor)
0130                  > (", \"flags.c_contiguous\"", ", \"flags.f_contiguous\"");
0131     static constexpr auto value
0132         = const_name("typing.Annotated[")
0133           + io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
0134           + npy_format_descriptor<typename Type::Scalar>::name + io_name("", "]")
0135           + const_name(", \"[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
0136           + const_name("]\"") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
0137 };
0138 
0139 // When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member
0140 // function. Falling back to a simple loop works around this issue.
0141 //
0142 // We need to disable the type-limits warning for the inner loop when size = 0.
0143 
0144 PYBIND11_WARNING_PUSH
0145 PYBIND11_WARNING_DISABLE_GCC("-Wtype-limits")
0146 
0147 template <typename T, int size>
0148 std::vector<T> convert_dsizes_to_vector(const Eigen::DSizes<T, size> &arr) {
0149     std::vector<T> result(size);
0150 
0151     for (size_t i = 0; i < size; i++) {
0152         result[i] = arr[i];
0153     }
0154 
0155     return result;
0156 }
0157 
0158 template <typename T, int size>
0159 Eigen::DSizes<T, size> get_shape_for_array(const array &arr) {
0160     Eigen::DSizes<T, size> result;
0161     const T *shape = arr.shape();
0162     for (size_t i = 0; i < size; i++) {
0163         result[i] = shape[i];
0164     }
0165 
0166     return result;
0167 }
0168 
0169 PYBIND11_WARNING_POP
0170 
0171 template <typename Type>
0172 struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
0173     static_assert(!std::is_pointer<typename Type::Scalar>::value,
0174                   PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
0175     using Helper = eigen_tensor_helper<Type>;
0176     static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
0177     PYBIND11_TYPE_CASTER(Type, temp_name);
0178 
0179     bool load(handle src, bool convert) {
0180         if (!convert) {
0181             if (!isinstance<array>(src)) {
0182                 return false;
0183             }
0184             array temp = array::ensure(src);
0185             if (!temp) {
0186                 return false;
0187             }
0188 
0189             if (!temp.dtype().is(dtype::of<typename Type::Scalar>())) {
0190                 return false;
0191             }
0192         }
0193 
0194         array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()> arr(
0195             reinterpret_borrow<object>(src));
0196 
0197         if (arr.ndim() != Type::NumIndices) {
0198             return false;
0199         }
0200         auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
0201 
0202         if (!Helper::is_correct_shape(shape)) {
0203             return false;
0204         }
0205 
0206 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0207         auto data_pointer = arr.data();
0208 #else
0209         // Handle Eigen bug
0210         auto data_pointer = const_cast<typename Type::Scalar *>(arr.data());
0211 #endif
0212 
0213         if (is_tensor_aligned(arr.data())) {
0214             value = Eigen::TensorMap<const Type, Eigen::Aligned>(data_pointer, shape);
0215         } else {
0216             value = Eigen::TensorMap<const Type>(data_pointer, shape);
0217         }
0218 
0219         return true;
0220     }
0221 
0222     static handle cast(Type &&src, return_value_policy policy, handle parent) {
0223         if (policy == return_value_policy::reference
0224             || policy == return_value_policy::reference_internal) {
0225             pybind11_fail("Cannot use a reference return value policy for an rvalue");
0226         }
0227         return cast_impl(&src, return_value_policy::move, parent);
0228     }
0229 
0230     static handle cast(const Type &&src, return_value_policy policy, handle parent) {
0231         if (policy == return_value_policy::reference
0232             || policy == return_value_policy::reference_internal) {
0233             pybind11_fail("Cannot use a reference return value policy for an rvalue");
0234         }
0235         return cast_impl(&src, return_value_policy::move, parent);
0236     }
0237 
0238     static handle cast(Type &src, return_value_policy policy, handle parent) {
0239         if (policy == return_value_policy::automatic
0240             || policy == return_value_policy::automatic_reference) {
0241             policy = return_value_policy::copy;
0242         }
0243         return cast_impl(&src, policy, parent);
0244     }
0245 
0246     static handle cast(const Type &src, return_value_policy policy, handle parent) {
0247         if (policy == return_value_policy::automatic
0248             || policy == return_value_policy::automatic_reference) {
0249             policy = return_value_policy::copy;
0250         }
0251         return cast(&src, policy, parent);
0252     }
0253 
0254     static handle cast(Type *src, return_value_policy policy, handle parent) {
0255         if (policy == return_value_policy::automatic) {
0256             policy = return_value_policy::take_ownership;
0257         } else if (policy == return_value_policy::automatic_reference) {
0258             policy = return_value_policy::reference;
0259         }
0260         return cast_impl(src, policy, parent);
0261     }
0262 
0263     static handle cast(const Type *src, return_value_policy policy, handle parent) {
0264         if (policy == return_value_policy::automatic) {
0265             policy = return_value_policy::take_ownership;
0266         } else if (policy == return_value_policy::automatic_reference) {
0267             policy = return_value_policy::reference;
0268         }
0269         return cast_impl(src, policy, parent);
0270     }
0271 
0272     template <typename C>
0273     static handle cast_impl(C *src, return_value_policy policy, handle parent) {
0274         object parent_object;
0275         bool writeable = false;
0276         switch (policy) {
0277             case return_value_policy::move:
0278                 if (std::is_const<C>::value) {
0279                     pybind11_fail("Cannot move from a constant reference");
0280                 }
0281 
0282                 src = Helper::alloc(std::move(*src));
0283 
0284                 parent_object
0285                     = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0286                 writeable = true;
0287                 break;
0288 
0289             case return_value_policy::take_ownership:
0290                 if (std::is_const<C>::value) {
0291                     // This cast is ugly, and might be UB in some cases, but we don't have an
0292                     // alternative here as we must free that memory
0293                     Helper::free(const_cast<Type *>(src));
0294                     pybind11_fail("Cannot take ownership of a const reference");
0295                 }
0296 
0297                 parent_object
0298                     = capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
0299                 writeable = true;
0300                 break;
0301 
0302             case return_value_policy::copy:
0303                 writeable = true;
0304                 break;
0305 
0306             case return_value_policy::reference:
0307                 parent_object = none();
0308                 writeable = !std::is_const<C>::value;
0309                 break;
0310 
0311             case return_value_policy::reference_internal:
0312                 // Default should do the right thing
0313                 if (!parent) {
0314                     pybind11_fail("Cannot use reference internal when there is no parent");
0315                 }
0316                 parent_object = reinterpret_borrow<object>(parent);
0317                 writeable = !std::is_const<C>::value;
0318                 break;
0319 
0320             default:
0321                 pybind11_fail("pybind11 bug in eigen.h, please file a bug report");
0322         }
0323 
0324         auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
0325             convert_dsizes_to_vector(Helper::get_shape(*src)), src->data(), parent_object);
0326 
0327         if (!writeable) {
0328             array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
0329         }
0330 
0331         return result.release();
0332     }
0333 };
0334 
0335 template <typename StoragePointerType,
0336           bool needs_writeable,
0337           enable_if_t<!needs_writeable, bool> = true>
0338 StoragePointerType get_array_data_for_type(array &arr) {
0339 #if EIGEN_VERSION_AT_LEAST(3, 4, 0)
0340     return reinterpret_cast<StoragePointerType>(arr.data());
0341 #else
0342     // Handle Eigen bug
0343     return reinterpret_cast<StoragePointerType>(const_cast<void *>(arr.data()));
0344 #endif
0345 }
0346 
0347 template <typename StoragePointerType,
0348           bool needs_writeable,
0349           enable_if_t<needs_writeable, bool> = true>
0350 StoragePointerType get_array_data_for_type(array &arr) {
0351     return reinterpret_cast<StoragePointerType>(arr.mutable_data());
0352 }
0353 
// Trait exposing, as SPT, the pointer type a TensorMap stores and is constructed from.
// One partial specialization is picked via void_t when MapType::StoragePointerType exists, the
// other when MapType::PointerArgType exists. Presumably these correspond to newer (>= 3.4) and
// older Eigen respectively, mirroring the EIGEN_VERSION_AT_LEAST checks in this header.
// The primary template is declared but not defined, so a map type with neither member fails to
// compile. NOTE(review): a map type providing both members would make the specializations
// ambiguous.
template <typename T, typename = void>
struct get_storage_pointer_type;

template <typename MapType>
struct get_storage_pointer_type<MapType, void_t<typename MapType::StoragePointerType>> {
    using SPT = typename MapType::StoragePointerType;
};

template <typename MapType>
struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType>> {
    using SPT = typename MapType::PointerArgType;
};
0366 
// Type caster for Eigen::TensorMap over a numpy buffer (no copy).
//
// load():
// - Accepts only an existing numpy array with the tensor's contiguity flag (C for RowMajor,
//   F for ColMajor), the exact dtype Type::Scalar, the right rank and shape, and alignment
//   when Options requests Eigen::Aligned.
// - Requires a writeable array when needs_writeable is set.
// - The resulting map views the array's memory directly.
//
// cast():
// - Only reference and reference_internal policies are supported. Automatic policies on
//   lvalues and pointers are mapped first; rvalues are passed through unchanged, so returning
//   a map by value with the default policy fails.
template <typename Type, int Options>
struct type_caster<Eigen::TensorMap<Type, Options>,
                   typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
    static_assert(!std::is_pointer<typename Type::Scalar>::value,
                  PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
    using MapType = Eigen::TensorMap<Type, Options>;
    using Helper = eigen_tensor_helper<remove_cv_t<Type>>;

    bool load(handle src, bool /*convert*/) {
        // Note that we have a lot more checks here as we want to make sure to avoid copies
        if (!isinstance<array>(src)) {
            return false;
        }
        auto arr = reinterpret_borrow<array>(src);
        // The array must already be contiguous in the tensor's storage order.
        if ((arr.flags() & compute_array_flag_from_tensor<Type>()) == 0) {
            return false;
        }

        if (!arr.dtype().is(dtype::of<typename Type::Scalar>())) {
            return false;
        }

        if (arr.ndim() != Type::NumIndices) {
            return false;
        }

        constexpr bool is_aligned = (Options & Eigen::Aligned) != 0;

        // An Aligned map may only wrap a suitably aligned buffer.
        if (is_aligned && !is_tensor_aligned(arr.data())) {
            return false;
        }

        auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);

        if (!Helper::is_correct_shape(shape)) {
            return false;
        }

        if (needs_writeable && !arr.writeable()) {
            return false;
        }

        // Picks the mutable_data() or data() overload depending on needs_writeable.
        auto result = get_array_data_for_type<typename get_storage_pointer_type<MapType>::SPT,
                                              needs_writeable>(arr);

        value.reset(new MapType(std::move(result), std::move(shape)));

        return true;
    }

    static handle cast(MapType &&src, return_value_policy policy, handle parent) {
        return cast_impl(&src, policy, parent);
    }

    static handle cast(const MapType &&src, return_value_policy policy, handle parent) {
        return cast_impl(&src, policy, parent);
    }

    static handle cast(MapType &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic
            || policy == return_value_policy::automatic_reference) {
            policy = return_value_policy::copy;
        }
        return cast_impl(&src, policy, parent);
    }

    static handle cast(const MapType &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic
            || policy == return_value_policy::automatic_reference) {
            policy = return_value_policy::copy;
        }
        return cast(&src, policy, parent);
    }

    static handle cast(MapType *src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic) {
            policy = return_value_policy::take_ownership;
        } else if (policy == return_value_policy::automatic_reference) {
            policy = return_value_policy::reference;
        }
        return cast_impl(src, policy, parent);
    }

    static handle cast(const MapType *src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic) {
            policy = return_value_policy::take_ownership;
        } else if (policy == return_value_policy::automatic_reference) {
            policy = return_value_policy::reference;
        }
        return cast_impl(src, policy, parent);
    }

    // Wraps the map's memory in a numpy array whose base is None (reference) or `parent`
    // (reference_internal). Every other policy, including the copy/take_ownership values the
    // overloads above map automatic policies to, ends in pybind11_fail.
    // NOTE(review): writeability is derived only from the constness of the map object C, not
    // from whether Type itself is const. A non-const TensorMap<const T> therefore yields a
    // writeable array over const-qualified data. Confirm this is intended.
    template <typename C>
    static handle cast_impl(C *src, return_value_policy policy, handle parent) {
        object parent_object;
        constexpr bool writeable = !std::is_const<C>::value;
        switch (policy) {
            case return_value_policy::reference:
                parent_object = none();
                break;

            case return_value_policy::reference_internal:
                // Default should do the right thing
                if (!parent) {
                    pybind11_fail("Cannot use reference internal when there is no parent");
                }
                parent_object = reinterpret_borrow<object>(parent);
                break;

            default:
                // move, take_ownership don't make any sense for a ref/map:
                pybind11_fail("Invalid return_value_policy for Eigen Map type, must be either "
                              "reference or reference_internal");
        }

        auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
            convert_dsizes_to_vector(Helper::get_shape(*src)),
            src->data(),
            std::move(parent_object));

        if (!writeable) {
            array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

        return result.release();
    }

    // Whether the map needs mutable access to the numpy buffer. With Eigen >= 3.4 this follows
    // the constness of the map's storage pointer type; older Eigen falls back to Type's constness.
#if EIGEN_VERSION_AT_LEAST(3, 4, 0)

    static constexpr bool needs_writeable = !std::is_const<typename std::remove_pointer<
        typename get_storage_pointer_type<MapType>::SPT>::type>::value;
#else
    // Handle Eigen bug
    static constexpr bool needs_writeable = !std::is_const<Type>::value;
#endif

protected:
    // Map built by load(); empty until a load succeeds.
    // TODO: Move to std::optional once std::optional has more support
    std::unique_ptr<MapType> value;

public:
    // return_descr forces the use of NDArray instead of ArrayLike since refs can only reference
    // arrays
    static constexpr auto name
        = return_descr(get_tensor_descriptor<Type, true, needs_writeable>::value);
    explicit operator MapType *() { return value.get(); }
    explicit operator MapType &() { return *value; }
    explicit operator MapType &&() && { return std::move(*value); }

    template <typename T_>
    using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
};
0519 
0520 PYBIND11_NAMESPACE_END(detail)
0521 PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)